In [None]:
# Code for the paper "Pruning Convolutional Neural Networks for Resource Efficient Inference"
# (Molchanov et al., ICLR 2017).
# Code adapted from https://github.com/eeric/channel_prune,
# which itself is adapted from https://github.com/jacobgil/pytorch-pruning

In [None]:
import os
import sys
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils import data

In [None]:
# Put the parent directory first on the import path so the project packages
# (networks, util, layers) imported below can be found.
if sys.path[0] != '..':
    sys.path.insert(0, '..')

# Drop the ROS Kinetic Python 2.7 dist-packages entry if it is present
# (presumably so its packages cannot shadow the Python 3 ones).
path_ros = '/opt/ros/kinetic/lib/python2.7/dist-packages'
if path_ros in sys.path:
    sys.path.remove(path_ros)  # removes the first occurrence only, like del at .index()
    
from networks.osvos_resnet import OSVOS_RESNET
from util import io_helper
from layers.osvos_layers import class_balanced_cross_entropy_loss

In [None]:
def get_net() -> nn.Module:
    net = OSVOS_RESNET(pretrained=False)
    path_model = Path('../models/resnet18_11_11_blackswan_epoch-9999.pth')
    parameters = torch.load(str(path_model), map_location=lambda storage, loc: storage)
    net.load_state_dict(parameters)
    net = net.cuda()
    return net

net = get_net()

In [None]:
def total_num_filters(net: nn.Module) -> int:
    """Count the filters (output channels) of every conv / transposed-conv layer in ``net``.

    Nested submodules are included because ``net.modules()`` walks the whole tree.
    A network without convolutions yields 0.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    return sum(module.out_channels
               for module in net.modules()
               if isinstance(module, conv_types))

n_filters_to_prune_per_iter = 512
n_filters = total_num_filters(net)
# Enough iterations to remove roughly two thirds of all filters in total.
n_iterations = int(n_filters / n_filters_to_prune_per_iter * 2 / 3)

for label, value in (('Filters in model:', n_filters),
                     ('Prune n filters per iteration:', n_filters_to_prune_per_iter),
                     ('Number of iterations:', n_iterations)):
    print(label, value)

In [None]:
class FilterPruner:
    """Rank the convolution filters of ``net`` with the first-order Taylor criterion.

    One ranking round is: ``reset()``; for every minibatch call ``forward(inputs)``
    and back-propagate a loss computed from its output; then
    ``normalize_ranks_per_layer()`` and ``get_prunning_plan(n)``.

    Layer indices used in ``filter_ranks`` and in the pruning plan are positions in
    ``self.conv_layers`` (every ``nn.Conv2d`` / ``nn.ConvTranspose2d`` in
    ``net.modules()`` order). They are not indices into the network's own containers.
    """

    def __init__(self, net: nn.Module):
        self.net = net
        # Same layer set that total_num_filters counts.
        self.conv_layers = [m for m in net.modules()
                            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))]
        self.reset()

    def reset(self) -> None:
        """Forget accumulated filter ranks and clear the network's parameter gradients."""
        self.filter_ranks = {}
        # different from original code
        self.net.zero_grad()

    def forward(self, x: torch.Tensor):
        """Run ``net`` on ``x`` while recording conv activations for ranking.

        Returns whatever ``self.net(x)`` returns. Filter ranks are accumulated into
        ``filter_ranks`` once a loss computed from that output is back-propagated.
        """
        handles = [layer.register_forward_hook(self._make_activation_hook(layer_index))
                   for layer_index, layer in enumerate(self.conv_layers)]
        try:
            return self.net(x)
        finally:
            # Hooks are temporary so ordinary net(x) calls are not affected.
            for handle in handles:
                handle.remove()

    def _make_activation_hook(self, layer_index: int):
        """Return a forward hook that attaches a rank-accumulating gradient hook to the output."""
        def hook(module, inputs, output):
            # register_hook raises on tensors that do not require grad (e.g. under no_grad).
            if output.requires_grad:
                output.register_hook(self._make_rank_hook(layer_index, output.detach()))
        return hook

    def _make_rank_hook(self, layer_index: int, activation: torch.Tensor):
        """Return a gradient hook that adds each filter's Taylor score to ``filter_ranks``."""
        # NOTE(review): if a conv output is later modified in place (e.g. an inplace ReLU
        # directly after it), the stored activation reflects the modified values.
        def hook(grad):
            # Assumes a batched 4-D activation (N, C, H, W).
            n, _, h, w = activation.size()
            # Mean over batch and spatial positions of activation * dLoss/dActivation.
            values = (activation * grad.detach()).sum(dim=3).sum(dim=2).sum(dim=0) / (n * h * w)
            if layer_index in self.filter_ranks:
                self.filter_ranks[layer_index] += values
            else:
                self.filter_ranks[layer_index] = values
        return hook

    def normalize_ranks_per_layer(self) -> None:
        """Replace each layer's ranks by their absolute values, L2-normalised, on the CPU."""
        for layer_index in list(self.filter_ranks):
            ranks = self.filter_ranks[layer_index].abs()
            norm = float(ranks.norm())
            # An all-zero layer would otherwise become NaN.
            if norm > 0:
                ranks = ranks / norm
            self.filter_ranks[layer_index] = ranks.cpu()

    def get_prunning_plan(self, n_filters_to_prune: int) -> list:
        """Return ``(layer_index, filter_index)`` pairs for the lowest-ranked filters.

        Filters are meant to be removed one after another in the returned order.
        Removing a filter shifts the indices of later filters in the same layer,
        so each index is reduced by the number of lower-index filters removed before it.
        """
        candidates = [(rank, layer_index, filter_index)
                      for layer_index, ranks in self.filter_ranks.items()
                      for filter_index, rank in enumerate(ranks.tolist())]
        candidates.sort()
        per_layer = {}
        for _, layer_index, filter_index in candidates[:n_filters_to_prune]:
            per_layer.setdefault(layer_index, []).append(filter_index)
        plan = []
        for layer_index in sorted(per_layer):
            for n_removed, filter_index in enumerate(sorted(per_layer[layer_index])):
                plan.append((layer_index, filter_index - n_removed))
        return plan

# A single pruner instance is shared by all ranking rounds below.
pruner = FilterPruner(net=net)

In [None]:
# Root of the DAVIS dataset. Set the DAVIS_ROOT environment variable to use it on other machines.
path_davis = Path(os.environ.get('DAVIS_ROOT', '/home/klaus/dev/datasets/DAVIS'))
data_loader = io_helper.get_data_loader_test(path_davis, batch_size=1, seq_name='blackswan')

def train(pruner: FilterPruner, data_loader: data.DataLoader, n_epochs: int = 1) -> None:
    """Forward and back-propagate the segmentation loss over ``data_loader``.

    No optimizer step is taken. Gradients from each ``loss.backward()`` accumulate
    across minibatches, and the backward pass is what lets the pruner observe them.

    Args:
        pruner: Pruner whose ``forward`` runs the network on a batch of images.
        data_loader: Yields dicts with ``'image'`` and ``'gt'`` tensors.
        n_epochs: Number of passes over ``data_loader``.
    """
    for _ in range(n_epochs):
        for minibatch in data_loader:
            inputs = Variable(minibatch['image']).cuda()
            gts = Variable(minibatch['gt']).cuda()

            outputs = pruner.forward(inputs)
            # NOTE(review): assumes the last element of the network output is the final
            # prediction - confirm against OSVOS_RESNET.
            loss = class_balanced_cross_entropy_loss(outputs[-1], gts, size_average=False)
            loss.backward()

In [None]:
def get_candidates_to_prune(pruner: FilterPruner, n_filters_to_prune: int, net: nn.Module,
                            data_loader: data.DataLoader) -> list:
    """Run one ranking round and return the pruner's plan for ``n_filters_to_prune`` filters.

    The pruner must provide ``normalize_ranks_per_layer`` and ``get_prunning_plan``.
    ``net`` is unused and is kept only for call-site compatibility.

    Returns:
        Whatever ``pruner.get_prunning_plan`` returns. The loop below unpacks the
        items as ``(layer_index, filter_index)`` pairs.
    """
    pruner.reset()
    train(pruner, data_loader)
    pruner.normalize_ranks_per_layer()
    # Fixed: this previously referenced the undefined name `prunner`.
    return pruner.get_prunning_plan(n_filters_to_prune)

print('Ranking filters')
prune_targets = get_candidates_to_prune(pruner, n_filters_to_prune_per_iter, net, data_loader)

# Tally how many filters the plan removes from each layer.
layers_prunned = {}
for layer_index, filter_index in prune_targets:
    layers_prunned[layer_index] = layers_prunned.get(layer_index, 0) + 1