## 导入相关包

In [7]:
import numpy as np
import os
import torch
from copy import deepcopy

import matplotlib.pyplot as plt
%matplotlib inline

In [162]:
# Load the trained checkpoint with every tensor mapped onto the CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
ckpt_path = "./model_59.pth"
pretrained_model = torch.load(ckpt_path, map_location=torch.device('cpu'))

## 查看darknet19结构

<img src="./darknet19-p1.png">


<img src="./darknet19-p2.png">


<img src="./fpn_head.png">

In [57]:
# List every parameter / buffer name stored in the checkpoint together with its shape.
for param_name, tensor in pretrained_model.items():
    print(param_name, tensor.shape)

module.fcos_body.backbone.block1.0.weight torch.Size([32, 3, 3, 3])
module.fcos_body.backbone.block1.1.weight torch.Size([32])
module.fcos_body.backbone.block1.1.bias torch.Size([32])
module.fcos_body.backbone.block1.1.running_mean torch.Size([32])
module.fcos_body.backbone.block1.1.running_var torch.Size([32])
module.fcos_body.backbone.block1.1.num_batches_tracked torch.Size([])
module.fcos_body.backbone.block1.4.weight torch.Size([64, 32, 3, 3])
module.fcos_body.backbone.block1.5.weight torch.Size([64])
module.fcos_body.backbone.block1.5.bias torch.Size([64])
module.fcos_body.backbone.block1.5.running_mean torch.Size([64])
module.fcos_body.backbone.block1.5.running_var torch.Size([64])
module.fcos_body.backbone.block1.5.num_batches_tracked torch.Size([])
module.fcos_body.backbone.block1.8.weight torch.Size([128, 64, 3, 3])
module.fcos_body.backbone.block1.9.weight torch.Size([128])
module.fcos_body.backbone.block1.9.bias torch.Size([128])
module.fcos_body.backbone.block1.9.running_me

In [2]:
# For every backbone conv (named by its index inside the nn.Sequential, e.g. 'block1.4'),
# record the conv whose output channels it consumes. 'pre': None marks the first conv,
# which reads the input image. The chain below is strictly sequential.
_backbone_convs = [
    'block1.0', 'block1.4', 'block1.8', 'block1.11', 'block1.14',
    'block1.18', 'block1.21', 'block1.24',
    'block2.1', 'block2.4', 'block2.7', 'block2.10', 'block2.13',
    'block3.1', 'block3.4', 'block3.7', 'block3.10', 'block3.13',
]
relation_map = {
    conv_name: {'pre': producer}
    for producer, conv_name in zip([None] + _backbone_convs[:-1], _backbone_convs)
}

# FPN lateral projections and the backbone stage output each one reads.
# NOTE(review): the pruning cell prunes fpn.prj_3 with block1.24's channels and fpn.prj_5
# with block3.13's (block1.24 is the shallowest, 256-channel output), i.e. the reverse of
# the proj5/proj3 labels used here. These 'proj*' entries are not read by the pruning loop,
# but confirm the intended naming before relying on them.
relation_map.update({
    'proj5': {'pre': 'block1.24'},
    'proj4': {'pre': 'block2.13'},
    'proj3': {'pre': 'block3.13'},
})

In [4]:
# Collect the BatchNorm layers: a layer is treated as BN if it owns a `running_mean` buffer.
# The stored name is the layer path, i.e. the key with the '.running_mean' suffix removed.
bn_layers = []
for param_name in pretrained_model.keys():
    if 'running_mean' not in param_name:
        continue
    layer_name = param_name[:-len('.running_mean')]
    print(layer_name)
    bn_layers.append(layer_name)

module.fcos_body.backbone.block1.1
module.fcos_body.backbone.block1.5
module.fcos_body.backbone.block1.9
module.fcos_body.backbone.block1.12
module.fcos_body.backbone.block1.15
module.fcos_body.backbone.block1.19
module.fcos_body.backbone.block1.22
module.fcos_body.backbone.block1.25
module.fcos_body.backbone.block2.2
module.fcos_body.backbone.block2.5
module.fcos_body.backbone.block2.8
module.fcos_body.backbone.block2.11
module.fcos_body.backbone.block2.14
module.fcos_body.backbone.block3.2
module.fcos_body.backbone.block3.5
module.fcos_body.backbone.block3.8
module.fcos_body.backbone.block3.11
module.fcos_body.backbone.block3.14


In [8]:
# L2-norm channel pruning of the DarkNet-19 backbone.
# For every backbone conv, the p_prune fraction of output filters with the smallest L2 norm
# is removed. The same selection is then applied to everything consuming those channels:
# the conv's BatchNorm, the next conv's input channels (via `relation_map`) and the FPN
# lateral convs. Depends on `bn_layers` and `relation_map` from the cells above.
ckpt_path = "./model_59.pth"
pretrained_model = torch.load(ckpt_path, map_location='cpu')
pruned_model = deepcopy(pretrained_model)

# Remove 30% of the output channels of every backbone conv (rounded half-up).
p_prune = 0.3

# Backbone conv whose output feeds each FPN lateral conv -> that lateral conv's weight key.
FPN_LATERALS = {
    'block1.24': 'module.fcos_body.fpn.prj_3.weight',
    'block2.13': 'module.fcos_body.fpn.prj_4.weight',
    'block3.13': 'module.fcos_body.fpn.prj_5.weight',
}
# NOTE(review): if the FPN/head consumes block3.13's output anywhere else (e.g. to build
# extra pyramid levels), those convs' input channels must be pruned too — confirm.


def _kept_channels(weight, ratio):
    """Return the ascending indices of the output channels (dim 0) of `weight` to keep.

    Filters are ranked by the L2 norm of their flattened values, and the
    int(out_channels * ratio + 0.5) weakest are dropped.
    """
    out_channels = weight.size(0)
    num_pruning = int(out_channels * ratio + 0.5)
    norms = torch.norm(weight.reshape(out_channels, -1), dim=1)
    _, kept = torch.topk(norms, out_channels - num_pruning, largest=True)
    # topk returns indices ordered by norm; sorting keeps survivors in their original order.
    return kept.sort()[0]


bn_layer_set = set(bn_layers)
# Weight key -> kept output indices. Each selection is computed once and reused, so a conv's
# outputs and all of its consumers' inputs are guaranteed to be pruned identically.
kept_idx_cache = {}


def _kept_channels_cached(weight_key):
    """Memoised `_kept_channels` on the *unpruned* pretrained weight `weight_key`."""
    if weight_key not in kept_idx_cache:
        kept_idx_cache[weight_key] = _kept_channels(pretrained_model[weight_key], p_prune)
    return kept_idx_cache[weight_key]


for k in pretrained_model:
    # Only conv weights drive the loop; BN params and conv biases are pruned with their conv.
    if not k.endswith('.weight'):
        continue
    indicator_k = k[:-len('.weight')]          # e.g. module.fcos_body.backbone.block1.4
    if indicator_k in bn_layer_set or 'backbone' not in indicator_k:
        continue

    path_parts = indicator_k.split('.')
    short_name = '.'.join(path_parts[-2:])     # e.g. 'block1.4'
    parent = '.'.join(path_parts[:-2])         # e.g. 'module.fcos_body.backbone'

    # Step 1. Prune input channels (dim 1) to match the previous conv's kept outputs.
    #         The first conv ('pre' is None) reads the image and keeps all of its inputs.
    pre_short = relation_map[short_name]['pre']
    if pre_short is not None:
        pre_key = parent + '.' + pre_short + '.weight'
        kept_in_idx = _kept_channels_cached(pre_key)
        print(pretrained_model[pre_key].size(0), len(kept_in_idx))
        pruned_model[k] = pruned_model[k][:, kept_in_idx]

    # Step 2. Prune output channels (dim 0) of the conv itself, and its bias if it has one.
    kept_out_idx = _kept_channels_cached(k)
    pruned_model[k] = pruned_model[k][kept_out_idx]
    conv_bias_key = indicator_k + '.bias'
    if conv_bias_key in pruned_model:
        pruned_model[conv_bias_key] = pruned_model[conv_bias_key][kept_out_idx]

    # Step 3. Prune the BatchNorm that directly follows the conv (its index is the conv's + 1,
    #         e.g. backbone.block1.0 -> backbone.block1.1).
    name_bn_layer = '.'.join(path_parts[:-1] + [str(int(path_parts[-1]) + 1)])
    for bn_param in ('weight', 'bias', 'running_mean', 'running_var'):
        bn_key = name_bn_layer + '.' + bn_param
        pruned_model[bn_key] = pruned_model[bn_key][kept_out_idx]

    # Step 4. If this conv feeds an FPN lateral conv, prune that conv's input channels.
    fpn_key = FPN_LATERALS.get(short_name)
    if fpn_key is not None:
        pruned_model[fpn_key] = pruned_model[fpn_key][:, kept_out_idx]

# Sanity checks: every consumer's input width must equal its producer's pruned output width.
for short_name, producer in relation_map.items():
    conv_key = 'module.fcos_body.backbone.' + short_name + '.weight'
    if producer['pre'] is None or conv_key not in pruned_model:
        continue
    pre_key = 'module.fcos_body.backbone.' + producer['pre'] + '.weight'
    assert pruned_model[conv_key].size(1) == pruned_model[pre_key].size(0), conv_key
for short_name, fpn_key in FPN_LATERALS.items():
    src_key = 'module.fcos_body.backbone.' + short_name + '.weight'
    assert pruned_model[fpn_key].size(1) == pruned_model[src_key].size(0), fpn_key

32 22
64 45
128 90
64 45
128 90
256 179
128 90
256 179
512 358
256 179
512 358
256 179
512 358
1024 717
512 358
1024 717
512 358


In [11]:
# Show each tensor's shape after pruning next to its original shape.
# 0-dim buffers (e.g. num_batches_tracked, torch.Size([])) are skipped.
for name, pruned_tensor in pruned_model.items():
    if pruned_tensor.dim() == 0:
        continue
    print(name, pruned_tensor.shape, pretrained_model[name].shape)

module.fcos_body.backbone.block1.0.weight torch.Size([22, 3, 3, 3]) torch.Size([32, 3, 3, 3])
module.fcos_body.backbone.block1.1.weight torch.Size([22]) torch.Size([32])
module.fcos_body.backbone.block1.1.bias torch.Size([22]) torch.Size([32])
module.fcos_body.backbone.block1.1.running_mean torch.Size([22]) torch.Size([32])
module.fcos_body.backbone.block1.1.running_var torch.Size([22]) torch.Size([32])
module.fcos_body.backbone.block1.4.weight torch.Size([45, 22, 3, 3]) torch.Size([64, 32, 3, 3])
module.fcos_body.backbone.block1.5.weight torch.Size([45]) torch.Size([64])
module.fcos_body.backbone.block1.5.bias torch.Size([45]) torch.Size([64])
module.fcos_body.backbone.block1.5.running_mean torch.Size([45]) torch.Size([64])
module.fcos_body.backbone.block1.5.running_var torch.Size([45]) torch.Size([64])
module.fcos_body.backbone.block1.8.weight torch.Size([90, 45, 3, 3]) torch.Size([128, 64, 3, 3])
module.fcos_body.backbone.block1.9.weight torch.Size([90]) torch.Size([128])
module.fco

In [10]:
# Save the pruned state dict. The prune ratio is encoded in the filename from `p_prune`
# so the name cannot drift out of sync with the setting (p_prune=0.3 -> 'darknet19_prune30.pth').
torch.save(pruned_model, f'darknet19_prune{round(p_prune * 100)}.pth')