模型剪枝

In [1]:
import os
import sys
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('.'))

In [5]:
import torch
from dltime.models.FCN import FCN
from torch import nn
from thop import profile
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from dltime.models.inception_atten import TSInceptionSelfAttnEncoderClassifier

In [3]:
ln = 2

Inception

In [6]:
model = TSInceptionSelfAttnEncoderClassifier(feat_dim=5, max_len=64, d_model=256, num_layers=2, num_classes=3, num_heads=4)

In [7]:
x = torch.randn(1, 5, 64)
model(x)

tensor([[-0.2571, -0.2522, -0.0929]], grad_fn=<AddmmBackward>)

In [8]:
profile(model, inputs=(x,))

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool1d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


(64848896.0, 1009155.0)

FCN

In [22]:
filters = [32, 64, 32]
new_filters = [16, 32, 16]

In [23]:
model = FCN(c_in=5, c_out=3, layers=filters)
model.load_state_dict(torch.load('./outputs/FCN_all_2022-09-29 00.40_32_64_32_prune.pth'))
for name, param in model.named_parameters():
    print(name, param.size())

state_dict = model.state_dict()

convblock1.conv1d.conv1d.weight torch.Size([32, 5, 7])
convblock1.conv1d.conv1d.bias torch.Size([32])
convblock1.bn.weight torch.Size([32])
convblock1.bn.bias torch.Size([32])
convblock2.conv1d.conv1d.weight torch.Size([64, 32, 5])
convblock2.conv1d.conv1d.bias torch.Size([64])
convblock2.bn.weight torch.Size([64])
convblock2.bn.bias torch.Size([64])
convblock3.conv1d.conv1d.weight torch.Size([32, 64, 3])
convblock3.conv1d.conv1d.bias torch.Size([32])
convblock3.bn.weight torch.Size([32])
convblock3.bn.bias torch.Size([32])
fc.weight torch.Size([3, 32])
fc.bias torch.Size([3])


In [24]:
new_model = FCN(c_in=5, c_out=3, layers=new_filters)
new_state_dict = new_model.state_dict()

In [38]:
profile(new_model, inputs=(x,))

[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool1d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


(315456.0, 4899.0)

In [25]:
x = torch.randn(1, 5, 64)
new_model(x)

tensor([[0.3798, 0.2862, 0.3341]], grad_fn=<SoftmaxBackward>)

第一层卷积

In [26]:
module = model.convblock1.conv1d.conv1d
for name, param in module.named_parameters():
    print(name, param.size())

weight torch.Size([32, 5, 7])
bias torch.Size([32])


In [27]:
res0 = []
if new_filters[0] != filters[0]:
    rate = 1 - new_filters[0] / filters[0]
    prune.ln_structured(module, name="weight", amount=rate, n=ln, dim=0)
    for i in range(module.weight.size(0)):
        if torch.sum(module.weight[i]) != 0.0:
            res0.append(i)
else:
    res0 = list(range(module.weight.size(0)))

In [28]:
new_state_dict['convblock1.conv1d.conv1d.weight'] = module.weight[res0].detach()
new_state_dict['convblock1.conv1d.conv1d.bias'] = module.bias[res0].detach()
new_state_dict['convblock1.bn.weight'] = state_dict['convblock1.bn.weight'][res0].detach()
new_state_dict['convblock1.bn.bias'] = state_dict['convblock1.bn.bias'][res0].detach()

第二层卷积

In [29]:
module = model.convblock2.conv1d.conv1d
for name, param in module.named_parameters():
    print(name, param.size())

weight torch.Size([64, 32, 5])
bias torch.Size([64])


In [30]:
res1 = []
if new_filters[1] != filters[1]:
    rate = 1 - new_filters[1] / filters[1]
    prune.ln_structured(module, name="weight", amount=rate, n=ln, dim=0)
    for i in range(module.weight.size(0)):
        if torch.sum(module.weight[i]) != 0.0:
            res1.append(i)
else:
    res1 = list(range(module.weight.size(0)))

In [31]:
new_state_dict['convblock2.conv1d.conv1d.weight'] = module.weight[res1][:, res0].detach()
new_state_dict['convblock2.conv1d.conv1d.bias'] = module.bias[res1].detach()
new_state_dict['convblock2.bn.weight'] = state_dict['convblock2.bn.weight'][res1].detach()
new_state_dict['convblock2.bn.bias'] = state_dict['convblock2.bn.bias'][res1].detach()

第三层卷积

In [32]:
module = model.convblock3.conv1d.conv1d
for name, param in module.named_parameters():
    print(name, param.size())

weight torch.Size([32, 64, 3])
bias torch.Size([32])


In [33]:
res2 = []
if new_filters[2] != filters[2]:
    rate = 1 - new_filters[2] / filters[2]
    prune.ln_structured(module, name="weight", amount=rate, n=ln, dim=0)
    for i in range(module.weight.size(0)):
        if torch.sum(module.weight[i]) != 0.0:
            res2.append(i)
else:
    res2 = list(range(module.weight.size(0)))

In [34]:
new_state_dict['convblock3.conv1d.conv1d.weight'] = module.weight[res2][:, res1].detach()
new_state_dict['convblock3.conv1d.conv1d.bias'] = module.bias[res2].detach()
new_state_dict['convblock3.bn.weight'] = state_dict['convblock3.bn.weight'][res2].detach()
new_state_dict['convblock3.bn.bias'] = state_dict['convblock3.bn.bias'][res2].detach()

全连接

In [35]:
new_state_dict['fc.weight'] = state_dict['fc.weight'][:, res2].detach()

In [36]:
state_dict['convblock3.conv1d.conv1d.bias'][res2]

tensor([ 0.0285, -0.0039,  0.0237,  0.0143, -0.0175, -0.0329, -0.0269, -0.0069,
        -0.0303, -0.0196,  0.0258,  0.0339, -0.0101, -0.0033, -0.0314, -0.0082])

In [37]:
torch.save(new_state_dict, 'outputs/FCN_prune_layer_16_32_16.pth')