# Target Dropout: Column-wise target dropout based on different PE-group sizes

In PE-wise threshold-based group lasso pruning (Yang, AAAI 2020), the PE group is treated as the basic pruning unit, and structured pruning is then performed at that granularity to introduce sparsity into the model.

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns;sns.set()

In this report, I'm going to reshape the 4-D weight tensor into a 2-D matrix based on the different PE-group sizes. The PE-group size varies from 16 down to 2. Following the Target Dropout paper, we first load the ResNet32 model:

In [2]:
# Load the best ResNet32 checkpoint on the CPU and pull out its state dict.
check_point = torch.load('./decay0.0002_fp_fflf_resnet32/model_best.pth.tar', map_location='cpu')
param = check_point['state_dict']
layers = param.items()

# Keep only the 4-D entries of the state dict, reporting each one's shape
# as it is collected.
conv_layers = {}
for name, tensor in layers:
    if tensor.dim() != 4:
        continue
    print(f"Layer: {name}, shape: {list(tensor.shape)}")
    conv_layers[name] = tensor

Layer: conv_1_3x3.weight, shape: [16, 3, 3, 3]
Layer: stage_1.0.conv_a.weight, shape: [16, 16, 3, 3]
Layer: stage_1.0.conv_b.weight, shape: [16, 16, 3, 3]
Layer: stage_1.1.conv_a.weight, shape: [16, 16, 3, 3]
Layer: stage_1.1.conv_b.weight, shape: [16, 16, 3, 3]
Layer: stage_1.2.conv_a.weight, shape: [16, 16, 3, 3]
Layer: stage_1.2.conv_b.weight, shape: [16, 16, 3, 3]
Layer: stage_1.3.conv_a.weight, shape: [16, 16, 3, 3]
Layer: stage_1.3.conv_b.weight, shape: [16, 16, 3, 3]
Layer: stage_1.4.conv_a.weight, shape: [16, 16, 3, 3]
Layer: stage_1.4.conv_b.weight, shape: [16, 16, 3, 3]
Layer: stage_2.0.conv_a.weight, shape: [32, 16, 3, 3]
Layer: stage_2.0.conv_b.weight, shape: [32, 32, 3, 3]
Layer: stage_2.1.conv_a.weight, shape: [32, 32, 3, 3]
Layer: stage_2.1.conv_b.weight, shape: [32, 32, 3, 3]
Layer: stage_2.2.conv_a.weight, shape: [32, 32, 3, 3]
Layer: stage_2.2.conv_b.weight, shape: [32, 32, 3, 3]
Layer: stage_2.3.conv_a.weight, shape: [32, 32, 3, 3]
Layer: stage_2.3.conv_b.weight, sha

Using a convolutional layer from the second stage (`stage_2.1.conv_a`) as the example, we reshape the 4-D tensor into a 2-D matrix for each of the different PE-group sizes.

In [3]:
# Example layer used by the rest of the notebook.
w_l = conv_layers['stage_2.1.conv_a.weight']
example_shape = list(w_l.shape)
print(f"example layer: {example_shape}")

example layer: [32, 32, 3, 3]


In [4]:
# PE-group sizes to sweep when reshaping the example weight tensor.
grp_size = [2, 4, 8, 16]
print(f"sweep group size: {grp_size}")

sweep group size: [2, 4, 8, 16]


In [5]:
def reshape_2_2D(input, g):
    """Reshape a 4-D weight tensor into a 2-D matrix for a given group size.

    Args:
        input: 4-D tensor; sizes along dims 0-3 are read via ``size(0..3)``.
        g: Group size. ``size(0) * size(1)`` should be divisible by ``g``,
            otherwise the element counts do not match and the reshape fails.

    Returns:
        Tensor of shape ``(g * size(2) * size(3), size(0) * size(1) // g)``
        holding the same elements as ``input`` in row-major order.
    """
    num_group = input.size(0) * input.size(1) // g
    rows = g * input.size(2) * input.size(3)

    # Reshape the 4-D weight tensor into a 2-D matrix. ``reshape`` (rather than
    # ``view``) also accepts non-contiguous inputs, copying only when needed.
    # NOTE(review): this is a plain row-major reinterpretation, so column c
    # holds the flattened elements c, c + num_group, c + 2 * num_group, ...
    # rather than one contiguous run of g * kh * kw weights. Confirm this is
    # the intended PE-group layout before relying on per-column statistics.
    return input.reshape(rows, num_group)

# Show the 2-D shape obtained from the example layer for every group size
# in the sweep.
for g in grp_size:
    reshape_layer = reshape_2_2D(w_l, g)
    print(f"group size={g}, shape={list(reshape_layer.size())}")

group size=2, shape=[18, 512]
group size=4, shape=[36, 256]
group size=8, shape=[72, 128]
group size=16, shape=[144, 64]


In [10]:
def forward(input, col_size=4, alpha=0.5, gamma=0.5):
    """Build a column-wise targeted-dropout keep-mask for a weight tensor.

    The weight is reshaped to 2-D with ``reshape_2_2D(input, col_size)`` and
    every column is scored by its L2 norm. Columns whose norm is at or below
    the value found at position ``int(num_columns * gamma)`` of the ascending
    sorted norms are dropout candidates; each candidate is dropped
    independently with probability ``alpha``.

    Args:
        input: Weight tensor accepted by ``reshape_2_2D`` (the notebook passes
            a 4-D convolution weight).
        col_size: Group size forwarded to ``reshape_2_2D``.
        alpha: Probability of dropping a candidate column.
        gamma: Fraction of the sorted column norms used to pick the threshold.

    Returns:
        Float tensor with the same shape as ``input`` holding 1.0 for kept
        weights and 0.0 for dropped ones.
    """
    w_i = reshape_2_2D(input, col_size)
    print(f"group size={col_size}, shape={list(w_i.size())}")

    grp_values = w_i.norm(p=2, dim=0)
    print(f'grp values size={grp_values.size()}')

    sorted_col, _ = torch.sort(grp_values.contiguous().view(-1), dim=0)
    print(sorted_col.size())

    # Clamp the index so gamma == 1.0 selects the largest norm instead of
    # reading one past the end of the sorted values.
    th_idx = min(int(grp_values.numel() * gamma), grp_values.numel() - 1)
    threshold = sorted_col[th_idx]
    print(f"threshold L2 norm: {threshold}, idx={th_idx}")

    mask_small = 1 - grp_values.gt(threshold).float() # mask for columns that are candidates for pruning
    mask_dropout = torch.rand_like(grp_values).lt(alpha).float()

    # A column is zeroed only when it is both a candidate and randomly selected.
    mask_keep = 1 - mask_small * mask_dropout

    # Broadcast the per-column decision down every row of the 2-D layout.
    mask_keep_2d = mask_keep.expand(w_i.size())
    print(mask_keep)
    print(mask_keep_2d[:,3])

    # ``reshape`` materialises the broadcast (stride-0) view and undoes the
    # row-major 2-D reshape. The previous in-place ``resize_as_`` reallocated
    # the underlying storage, leaving everything past the first ``num_group``
    # values uninitialised.
    mask_keep_original = mask_keep_2d.reshape(input.size())
    return mask_keep_original

In [14]:
# Build the keep-mask for the example layer with the default settings, then
# view it in the 36 x 256 layout and show one column.
mask_keep_original_test = forward(w_l)
mask_keep_2d_test = mask_keep_original_test.view(36, 256)
column_one = mask_keep_2d_test[:, 1]
print(column_one.float())

group size=4, shape=[36, 256]
grp values size=torch.Size([256])
torch.Size([256])
threshold L2 norm: 0.28617072105407715, idx=128
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0.,
        1., 1., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,
        0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.,
        1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0.,
        1., 0., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
        0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 0., 1.,
        0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        0., 0., 0., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.,
        1., 1., 0., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.,
        1., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0.

In [None]:
# block based target dropout
block_size = 8

def td_forward(input, alpha=0.5, gamma=0.5):
    """Build a block-wise targeted-dropout keep-mask for a 4-D weight tensor.

    Dims 0 and 1 of ``input`` are tiled into ``block_size`` x ``block_size``
    blocks (separately for every index of dims 2 and 3). Each block is scored
    by its mean absolute value; blocks scoring at or below the value found at
    position ``int(num_blocks * gamma)`` of the ascending sorted scores are
    dropout candidates, and each candidate is dropped independently with
    probability ``alpha``.

    Args:
        input: 4-D weight tensor. NOTE(review): dims 0 and 1 are presumably
            expected to be multiples of ``block_size``; otherwise the pooled
            grid is truncated and the returned mask is smaller than ``input``.
        alpha: Probability of dropping a candidate block.
        gamma: Fraction of the sorted block scores used to pick the threshold.

    Returns:
        Float tensor of 1.0 (keep) / 0.0 (drop) values, laid out like
        ``input`` (same dimension order).
    """
    # Move dims 0 and 1 to the end so the 2-D pooling runs over them;
    # sort blocks by mean absolute value.
    block_values = F.avg_pool2d(input.detach().abs().permute(2, 3, 0, 1),
                                kernel_size=(block_size, block_size),
                                stride=(block_size, block_size))

    print(f'block values size: {block_values.size()}')

    sorted_block_values, _ = torch.sort(block_values.contiguous().view(-1))

    # Clamp the index so gamma == 1.0 selects the largest score instead of
    # reading one past the end of the sorted values.
    thre_index = min(int(block_values.numel() * gamma), block_values.numel() - 1)
    threshold = sorted_block_values[thre_index]
    mask_small = 1 - block_values.gt(threshold).float() # mask for blocks candidates for pruning
    mask_dropout = torch.rand_like(block_values).lt(alpha).float()
    # A block is zeroed only when it is both a candidate and randomly selected.
    mask_keep = 1.0 - mask_small * mask_dropout

    print(f'mask_keep size = {mask_keep.size()}')

    # Upsample each block decision back to block_size x block_size entries,
    # then restore the original dimension order.
    mask_keep_original = F.interpolate(mask_keep,
                                       scale_factor=(block_size, block_size)).permute(2, 3, 0, 1)
    print(f'mask keep original size = {mask_keep_original.size()}')
    return mask_keep_original

# Build the block-wise keep-mask for the example layer.
block_mask_w = td_forward(w_l)