# SNIP: Single-shot Network Pruning based on Connection Sensitivity (ICLR 2019)
#### Code adapted mainly from https://github.com/mil-ad/snip

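#### CIFAR-10 dataset, VGG-16 model (138 M parameters is the ImageNet VGG-16 figure; the CIFAR-10 variant used here is far smaller — roughly 15 M prunable weights, judging by the 761,993 connections kept at a 5% keep ratio)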

In [1]:
import os
import copy
import types
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms

#from tensorboardX import SummaryWriter
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.contrib.handlers import ProgressBar

from snip import snip_forward_conv2d, snip_forward_linear
from train import cifar10_experiment, apply_prune_mask

# GPU selection: enumerate devices in PCI-bus order and expose only device 1.
# Set before any CUDA call so the CUDA runtime picks it up.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Reproducibility: fixed RNG seed plus deterministic, non-autotuned cuDNN kernels.
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the (single visible) GPU when present, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Training configuration

In [2]:
# --- Optimisation hyper-parameters ---
BATCH_SIZE = 128
EPOCHS = 100                   # originally 250
INIT_LR = 0.1
WEIGHT_DECAY_RATE = 5e-4
# Not referenced in the visible cells; presumably mirrors the scheduler step
# size used inside train.cifar10_experiment — TODO confirm.
LR_DECAY_INTERVAL = 30_000

# --- Bookkeeping ---
LOG_INTERVAL = 20              # iterations between training-loss logs (that logging is currently commented out)
REPEAT_WITH_DIFFERENT_SEED = 1

## SNIP algorithm

1. 1개의 minibatch 샘플링
2. 모델 파라미터 (weight) variance scaling initialization
3. 각 connection(weight)에 곱해지는 보조 mask 변수 c (초기값 1)에 대한 loss의 gradient |∂L/∂c|를 구해서 connection sensitivity 계산
4. connection sensitivity sorting 후 pruning (top-k)
5. 프루닝된 모델 학습

In [3]:
def snip_forward_conv2d(self, x):
    """Conv2d forward pass with the weight gated elementwise by ``self.weight_mask``.

    Meant to be bound onto an ``nn.Conv2d`` instance (``types.MethodType``)
    so the loss becomes differentiable w.r.t. the mask. Always uses zero
    padding via ``F.conv2d``; the module's ``padding_mode`` is not consulted.
    """
    masked_weight = self.weight * self.weight_mask
    return F.conv2d(
        x, masked_weight, self.bias,
        self.stride, self.padding, self.dilation, self.groups,
    )


def snip_forward_linear(self, x):
    """Linear forward pass with the weight gated elementwise by ``self.weight_mask``.

    Meant to be bound onto an ``nn.Linear`` instance (``types.MethodType``)
    so the loss becomes differentiable w.r.t. the mask.
    """
    masked_weight = self.weight * self.weight_mask
    return F.linear(x, masked_weight, self.bias)
    

def SNIP(net, keep_ratio, train_dataloader, device):
    """Compute SNIP keep-masks for every Conv2d/Linear weight of ``net``.

    Connection sensitivity is |dL/dc| for an auxiliary multiplicative mask c
    (initialised to 1) on each weight, evaluated on a single minibatch. The
    ``keep_ratio`` fraction of connections with the highest sensitivity is
    kept; exactly ``int(total_weights * keep_ratio)`` connections are kept,
    with ties at the cut-off broken arbitrarily.

    Args:
        net: model to score. It is deep-copied, so ``net`` itself is not
            modified. Its output is passed to ``F.nll_loss``, so it must
            return log-probabilities.
        keep_ratio: fraction of Conv2d/Linear weights to keep, in (0, 1].
        train_dataloader: iterable of ``(inputs, targets)`` batches; only the
            first batch is used.
        device: device the minibatch is moved to before the forward pass.

    Returns:
        List of float32 0/1 tensors, one per Conv2d/Linear layer in
        ``net.modules()`` order, each shaped like that layer's weight.

    Raises:
        ValueError: if ``keep_ratio`` is outside (0, 1], if it would keep no
            connection at all, or if ``net`` has no Conv2d/Linear layers.
    """
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError('keep_ratio must be in (0, 1], got {}'.format(keep_ratio))

    inputs, targets = next(iter(train_dataloader))  # sample a single minibatch
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Work on a fresh copy so the network that will actually be trained is
    # left untouched.
    net = copy.deepcopy(net)

    # Only the masks need gradients. Freezing every original parameter
    # (weights, biases, normalisation affine terms) skips their gradient
    # computation without changing the gradients w.r.t. the masks.
    for param in net.parameters():
        param.requires_grad_(False)

    # Monkey-patch Conv2d/Linear layers so their forward uses
    # weight * weight_mask; the loss gradient then lands on the mask.
    masked_layers = []
    for layer in net.modules():
        if isinstance(layer, nn.Conv2d):
            patched_forward = snip_forward_conv2d
        elif isinstance(layer, nn.Linear):
            patched_forward = snip_forward_linear
        else:
            continue
        # Auxiliary multiplicative variable c (all ones) -> trainable parameter.
        layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))
        # Variance-scaling re-initialisation of the copy's weights
        # (kaiming_normal_() is an alternative).
        # NOTE(review): the caller's ``net`` keeps its own initialisation, so
        # the masks are scored on different weights than the ones that will be
        # trained — confirm this is intended.
        nn.init.xavier_normal_(layer.weight)
        layer.forward = types.MethodType(patched_forward, layer)
        masked_layers.append(layer)

    if not masked_layers:
        raise ValueError('SNIP found no Conv2d or Linear layers to prune')

    # Compute gradients w.r.t. the masks (they are never applied). Call the
    # module rather than .forward() so any registered hooks still run.
    net.zero_grad()
    outputs = net(inputs)
    loss = F.nll_loss(outputs, targets)
    loss.backward()

    # Connection sensitivity |dL/dc|. A layer that was not on the forward
    # path receives no gradient; treat it as zero sensitivity.
    grads_abs = []
    for layer in masked_layers:
        grad = layer.weight_mask.grad
        if grad is None:
            grads_abs.append(torch.zeros_like(layer.weight_mask))
        else:
            grads_abs.append(torch.abs(grad))

    # Rank all connections globally. The paper normalises by the total
    # sensitivity, but that does not change the ranking, so it is skipped;
    # this also avoids a 0/0 (NaN) when every gradient is zero.
    all_scores = torch.cat([torch.flatten(g) for g in grads_abs])
    num_params_to_keep = int(all_scores.numel() * keep_ratio)
    if num_params_to_keep == 0:
        raise ValueError(
            'keep_ratio={} keeps no connection out of {}'.format(keep_ratio, all_scores.numel()))

    # Select exactly num_params_to_keep connections by index. A ">= threshold"
    # test would also keep every connection tied with the threshold (e.g. all
    # zero-gradient connections when the threshold is 0), exceeding keep_ratio.
    _, keep_idx = torch.topk(all_scores, num_params_to_keep, sorted=False)
    flat_mask = torch.zeros(all_scores.numel(), dtype=torch.float32, device=all_scores.device)
    flat_mask[keep_idx] = 1.0

    # Split the flat mask back into one tensor per layer, shaped like its weight.
    sizes = [g.numel() for g in grads_abs]
    keep_masks = [chunk.view_as(g) for chunk, g in zip(torch.split(flat_mask, sizes), grads_abs)]

    print(torch.sum(flat_mask == 1))  # number of connections kept

    return keep_masks

## Start pruning

In [5]:
net, optimiser, lr_scheduler, train_loader, val_loader = cifar10_experiment(device)

# Pre-training pruning with SNIP: keep the 5% most sensitive connections
# (i.e. prune 95%) before any training step.
KEEP_RATIO = 0.05
# perf_counter is monotonic and meant for elapsed-time measurement,
# unlike wall-clock time.time().
t = time.perf_counter()
keep_masks = SNIP(net, KEEP_RATIO, train_loader, device)
apply_prune_mask(net, keep_masks)
if device.type == 'cuda':
    # CUDA kernels run asynchronously; wait for them so the timing is real.
    torch.cuda.synchronize()
print('Pruning is done in {:.2f} sec.'.format(time.perf_counter() - t))

Files already downloaded and verified
tensor(761993, device='cuda:0')
Pruning is done in 3.16 sec.


## Start training

In [8]:
# Ignite engines: the trainer optimises the (pruned) net with NLL loss; the
# evaluator measures accuracy and NLL on the validation loader.
trainer = create_supervised_trainer(net, optimiser, F.nll_loss, device)
evaluator = create_supervised_evaluator(net, {
    'accuracy': Accuracy(),
    'nll': Loss(F.nll_loss)
}, device)

pbar = ProgressBar()
pbar.attach(trainer)


@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    """Advance the LR scheduler after every training iteration.

    Because the scheduler is stepped per iteration (not per epoch), any step
    size it was configured with is counted in iterations.
    """
    lr_scheduler.step()


@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch(engine):
    """Run validation after every epoch and report the resulting metrics.

    Previously the metrics were computed and then discarded; they are now
    printed through the progress bar so they do not garble its display.
    """
    evaluator.run(val_loader)

    metrics = evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_nll = metrics['nll']

    pbar.log_message("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
                     .format(engine.state.epoch, avg_accuracy, avg_nll))

In [None]:
# Train the pruned network for EPOCHS epochs over the CIFAR-10 training loader.
trainer.run(train_loader, max_epochs=EPOCHS)