In [1]:
#===================================================================================================
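# Global pruning
# - Prunes, across the whole model at once, the parameters whose connections
#   between modules have the least influence.
# - Regarded here as the most effective approach; it can be applied to the
#   entire model in a single pass.
#
# Reference article : https://huffon.github.io/2020/03/15/torch-pruning/
# Reference notebook: https://github.com/Huffon/nlp-various-tutorials/blob/master/pruning-bert.ipynb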
#===================================================================================================
import torch
import torch.nn.utils.prune as prune

In [2]:
# Load the fine-tuned DistilBERT sequence-classification checkpoint (3 labels).
from transformers import BertModel, DistilBertModel, DistilBertForSequenceClassification
model_path = '../../../model/distilbert/distilbert-0331-TS-nli-0.1-10'

# Other model classes that were tried. The rest of this notebook addresses
# modules through `model.distilbert.*`, so switching class also means
# rewriting those attribute paths.
#model = BertModel.from_pretrained(model_path)
#model = DistilBertModel.from_pretrained(model_path)
model=DistilBertForSequenceClassification.from_pretrained(model_path, num_labels=3)

# Number of transformer blocks, read from the loaded model instead of a
# hand-maintained constant, so the per-layer loops further down always cover
# exactly the blocks this checkpoint has.
encoder_layer_num = len(model.distilbert.transformer.layer)

print(model)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(167550, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
      

In [3]:
# Total parameter count as reported by the model.
total_param_count = model.num_parameters()
print(total_param_count)

172193283


In [4]:
# Show the parameters of the word-embedding table.
embeddingmodule = model.distilbert.embeddings.word_embeddings
[(param_name, param) for param_name, param in embeddingmodule.named_parameters()]

[('weight',
  Parameter containing:
  tensor([[ 0.0269, -0.0009, -0.0439,  ...,  0.0020, -0.0098,  0.0094],
          [-0.0190, -0.0209, -0.0088,  ..., -0.0216, -0.0193, -0.0079],
          [-0.0118, -0.0189, -0.0004,  ..., -0.0270, -0.0167, -0.0258],
          ...,
          [-0.0197, -0.0128, -0.0483,  ..., -0.0389, -0.0192, -0.0250],
          [-0.0253, -0.0202, -0.0244,  ..., -0.0371, -0.0323, -0.0193],
          [ 0.0077, -0.0302, -0.0408,  ..., -0.0224, -0.0575, -0.0220]],
         requires_grad=True))]

In [5]:
# Show the parameters of transformer.layer[0].attention.q_lin.
module1 = model.distilbert.transformer.layer[0].attention.q_lin
[(param_name, param) for param_name, param in module1.named_parameters()]

[('weight',
  Parameter containing:
  tensor([[-0.0176,  0.0006, -0.0255,  ..., -0.0118, -0.0157,  0.0184],
          [ 0.1103,  0.0620, -0.0436,  ...,  0.0194, -0.0158, -0.0365],
          [ 0.0516,  0.0002,  0.0106,  ..., -0.0203, -0.0609, -0.0059],
          ...,
          [-0.0326,  0.0789, -0.0282,  ...,  0.0675, -0.0544,  0.0092],
          [-0.0324, -0.0388,  0.0506,  ..., -0.0011, -0.0232,  0.0773],
          [ 0.0031, -0.0976,  0.0209,  ...,  0.0206, -0.0318, -0.0092]],
         requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([-6.8313e-02,  6.7761e-02,  1.6668e-01,  1.2392e-01,  2.9475e-01,
          -1.0670e-01, -1.5829e-01,  2.3700e-01, -4.7921e-01, -2.6351e-01,
          -5.6057e-02, -4.0667e-01,  1.8796e-01, -8.7701e-02, -5.2525e-02,
           2.6205e-01, -6.7708e-02,  2.6149e-02, -3.9226e-01,  2.4256e-01,
           2.9250e-01,  4.2455e-02, -3.9571e-01, -3.4621e-02, -1.9131e-01,
           5.8470e-01,  1.8993e-01,  9.0211e-02,  3.6801e-01,  2.9665e-01,
 

In [6]:
# Shapes of the q_lin weight matrix and bias vector, in that order.
for tensor_name in ("weight", "bias"):
    print(getattr(module1, tensor_name).shape)

torch.Size([768, 768])
torch.Size([768])


In [7]:
# Show the parameters of transformer.layer[0].attention.out_lin.
module2 = model.distilbert.transformer.layer[0].attention.out_lin
[(param_name, param) for param_name, param in module2.named_parameters()]

[('weight',
  Parameter containing:
  tensor([[ 2.1692e-02, -1.0404e-03,  3.8681e-03,  ...,  4.6904e-03,
           -6.8416e-03,  1.2067e-02],
          [ 5.1262e-03, -3.4173e-02,  8.2244e-03,  ..., -4.0997e-02,
           -5.2522e-02,  8.1092e-03],
          [ 7.2423e-03, -2.3219e-03,  2.7945e-02,  ...,  3.2959e-02,
            2.1564e-02, -2.8473e-03],
          ...,
          [-1.4599e-02, -2.1674e-02, -1.5329e-03,  ...,  1.6566e-02,
            2.8185e-02,  2.4331e-02],
          [-7.6789e-03, -2.2596e-02,  1.0172e-03,  ...,  9.8938e-03,
           -1.3059e-02,  1.0753e-02],
          [-9.2413e-06, -2.5446e-02, -3.9522e-02,  ...,  3.9514e-04,
           -1.6852e-02,  4.4802e-02]], requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([ 6.2471e-02,  6.5976e-02,  1.2663e-02, -1.5958e-01, -3.4438e-02,
           8.6972e-02, -6.3077e-03, -2.8673e-02,  2.3975e-02, -3.0535e-02,
           2.2055e-02,  3.5911e-02,  6.3401e-02,  1.4193e-01, -8.8301e-06,
          -6.0403e-03, -1

In [8]:
# Shapes of the out_lin weight matrix and bias vector, in that order.
for tensor_name in ("weight", "bias"):
    print(getattr(module2, tensor_name).shape)

torch.Size([768, 768])
torch.Size([768])


In [9]:
# Show the parameters of transformer.layer[0].ffn.lin1.
module3 = model.distilbert.transformer.layer[0].ffn.lin1
[(param_name, param) for param_name, param in module3.named_parameters()]

[('weight',
  Parameter containing:
  tensor([[ 0.0334, -0.0455,  0.0813,  ..., -0.0059, -0.0378, -0.0113],
          [-0.0289,  0.0478,  0.0159,  ..., -0.0403, -0.0670,  0.0322],
          [-0.0055,  0.0556,  0.0129,  ..., -0.0686, -0.0723,  0.0329],
          ...,
          [-0.0040,  0.0732, -0.0095,  ..., -0.0527,  0.0133, -0.0238],
          [-0.0237,  0.1044, -0.0009,  ..., -0.0174, -0.0361,  0.0447],
          [-0.0470,  0.0253,  0.0155,  ..., -0.0655,  0.0209,  0.0060]],
         requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([-0.1553, -0.0331, -0.1477,  ...,  0.0213, -0.1634, -0.0975],
         requires_grad=True))]

In [10]:
# Show the parameters of transformer.layer[0].ffn.lin2.
module4 = model.distilbert.transformer.layer[0].ffn.lin2
[(param_name, param) for param_name, param in module4.named_parameters()]

[('weight',
  Parameter containing:
  tensor([[ 0.0507,  0.0234,  0.0272,  ..., -0.0213,  0.0287,  0.0295],
          [ 0.0034, -0.0230,  0.0321,  ..., -0.0344,  0.0244, -0.0722],
          [ 0.0362, -0.0423,  0.0281,  ..., -0.0516,  0.0142, -0.0469],
          ...,
          [-0.0182, -0.0271, -0.0829,  ...,  0.0187,  0.0032,  0.0447],
          [-0.0271, -0.0010,  0.0434,  ..., -0.0393,  0.0398, -0.0302],
          [ 0.0332,  0.0008, -0.0188,  ..., -0.0106, -0.0335,  0.0261]],
         requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([-8.2931e-03,  3.2956e-02, -5.0556e-02,  1.9224e-01,  2.4758e-02,
           1.6209e-02,  1.9909e-02,  2.8969e-02,  4.9089e-02,  5.2501e-02,
           2.1930e-02, -3.1483e-03, -4.5500e-02, -1.4619e-02, -2.4293e-02,
          -3.8564e-02, -5.5504e-02, -5.7840e-02,  4.5654e-02,  3.0383e-02,
          -8.7322e-03,  9.0415e-02, -4.3904e-02, -3.5723e-02,  1.0012e-01,
          -6.8407e-02, -7.5555e-03, -2.1584e-02, -4.1634e-02, -6.3506e-02,
 

In [11]:
# Before pruning: inspect the layer-0 attention query projection (q_lin).
module = model.distilbert.transformer.layer[0].attention.q_lin
named_params = list(module.named_parameters())
print(named_params)
print('\n')

# Count the weight entries that are exactly zero before pruning.
zero_count = (module.weight == 0).sum()
print('*0 계수:{}'.format(zero_count))

[('weight', Parameter containing:
tensor([[-0.0176,  0.0006, -0.0255,  ..., -0.0118, -0.0157,  0.0184],
        [ 0.1103,  0.0620, -0.0436,  ...,  0.0194, -0.0158, -0.0365],
        [ 0.0516,  0.0002,  0.0106,  ..., -0.0203, -0.0609, -0.0059],
        ...,
        [-0.0326,  0.0789, -0.0282,  ...,  0.0675, -0.0544,  0.0092],
        [-0.0324, -0.0388,  0.0506,  ..., -0.0011, -0.0232,  0.0773],
        [ 0.0031, -0.0976,  0.0209,  ...,  0.0206, -0.0318, -0.0092]],
       requires_grad=True)), ('bias', Parameter containing:
tensor([-6.8313e-02,  6.7761e-02,  1.6668e-01,  1.2392e-01,  2.9475e-01,
        -1.0670e-01, -1.5829e-01,  2.3700e-01, -4.7921e-01, -2.6351e-01,
        -5.6057e-02, -4.0667e-01,  1.8796e-01, -8.7701e-02, -5.2525e-02,
         2.6205e-01, -6.7708e-02,  2.6149e-02, -3.9226e-01,  2.4256e-01,
         2.9250e-01,  4.2455e-02, -3.9571e-01, -3.4621e-02, -1.9131e-01,
         5.8470e-01,  1.8993e-01,  9.0211e-02,  3.6801e-01,  2.9665e-01,
        -2.1002e-01,  2.7735e-01, 

In [12]:
# Select only the modules that should be pruned.
# => NOTE: the attribute paths below are specific to this model's structure
#    and must be rewritten for a different architecture.

# Start with the word-embedding table (pruned on its weight only).
prune_targets = [(model.distilbert.embeddings.word_embeddings, "weight")]

for layer_idx in range(encoder_layer_num):
    block = model.distilbert.transformer.layer[layer_idx]
    linears = (
        block.attention.q_lin,    # attention query projection
        block.attention.k_lin,    # attention key projection
        block.attention.v_lin,    # attention value projection
        block.attention.out_lin,  # attention output projection
        block.ffn.lin1,           # feed-forward, first linear
        block.ffn.lin2,           # feed-forward, second linear
    )
    # Per block: all six weights first, then all six biases.
    for tensor_name in ("weight", "bias"):
        prune_targets.extend((linear, tensor_name) for linear in linears)

# global_unstructured is given a tuple of (module, parameter_name) pairs.
parameters_to_prune = tuple(prune_targets)

# Apply L1 pruning globally: 20% of all selected entries, pooled together.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

In [None]:
# Report sparsity after global pruning.
# Global pruning ranks all selected entries together, so an individual tensor
# is usually not pruned at exactly the requested ratio; only the sparsity over
# every pruned tensor combined matches it.
# => NOTE: the attribute paths below are specific to this model's structure
#    and must be rewritten for a different architecture.

# Per-layer sparsity of the attention query / key / value weight matrices.
for i in range(encoder_layer_num):
    attention = model.distilbert.transformer.layer[i].attention
    for label, linear in (
        ("query", attention.q_lin),
        ("key", attention.k_lin),
        ("value", attention.v_lin),
    ):
        zero_entries = int(torch.sum(linear.weight == 0))
        print(
            "Sparsity in Layer {}-th {} weight: {:.2f}%".format(
                i + 1,
                label,
                100. * zero_entries / linear.weight.nelement(),
            )
        )
    print()

# Global sparsity over every tensor that was handed to global_unstructured.
numerator, denominator = 0, 0
for pruned_module, tensor_name in parameters_to_prune:
    tensor = getattr(pruned_module, tensor_name)
    numerator += int(torch.sum(tensor == 0))
    denominator += tensor.nelement()

print("Global sparsity: {:.2f}%".format(100. * numerator / denominator))

In [18]:
# Make the pruning permanent: prune.remove drops the `<name>_orig` parameter
# and `<name>_mask` buffer and leaves the masked tensor as the plain parameter.
#
# The targets come from `parameters_to_prune`, so this always covers exactly
# the tensors that were pruned and cannot drift from that list.
# Tensors with no `<name>_orig` attribute have no active pruning and are
# skipped, so re-running this cell does nothing instead of raising ValueError.
#
# NOTE: cells that inspect `weight_orig` / `weight_mask` only show them if
# they are run before this cell.
for pruned_module, tensor_name in parameters_to_prune:
    if hasattr(pruned_module, tensor_name + "_orig"):
        prune.remove(pruned_module, tensor_name)
        

In [13]:
# Names of all buffers registered on the model. The recorded output lists the
# `*_mask` buffers, i.e. it was captured while pruning was still active.
{buffer_name: buffer for buffer_name, buffer in model.named_buffers()}.keys()

dict_keys(['distilbert.embeddings.position_ids', 'distilbert.embeddings.word_embeddings.weight_mask', 'distilbert.transformer.layer.0.attention.q_lin.weight_mask', 'distilbert.transformer.layer.0.attention.q_lin.bias_mask', 'distilbert.transformer.layer.0.attention.k_lin.weight_mask', 'distilbert.transformer.layer.0.attention.k_lin.bias_mask', 'distilbert.transformer.layer.0.attention.v_lin.weight_mask', 'distilbert.transformer.layer.0.attention.v_lin.bias_mask', 'distilbert.transformer.layer.0.attention.out_lin.weight_mask', 'distilbert.transformer.layer.0.attention.out_lin.bias_mask', 'distilbert.transformer.layer.0.ffn.lin1.weight_mask', 'distilbert.transformer.layer.0.ffn.lin1.bias_mask', 'distilbert.transformer.layer.0.ffn.lin2.weight_mask', 'distilbert.transformer.layer.0.ffn.lin2.bias_mask', 'distilbert.transformer.layer.1.attention.q_lin.weight_mask', 'distilbert.transformer.layer.1.attention.q_lin.bias_mask', 'distilbert.transformer.layer.1.attention.k_lin.weight_mask', 'disti

In [15]:
# After pruning: inspect the layer-0 attention query projection (q_lin) again.
# The recorded output lists the parameters as `weight_orig` / `bias_orig`.
module = model.distilbert.transformer.layer[0].attention.q_lin
named_params = list(module.named_parameters())
print(named_params)
print('\n')

# Count the weight entries that are exactly zero after pruning.
zero_count = (module.weight == 0).sum()
print('*0 계수:{}'.format(zero_count))

[('weight_orig', Parameter containing:
tensor([[-0.0176,  0.0006, -0.0255,  ..., -0.0118, -0.0157,  0.0184],
        [ 0.1103,  0.0620, -0.0436,  ...,  0.0194, -0.0158, -0.0365],
        [ 0.0516,  0.0002,  0.0106,  ..., -0.0203, -0.0609, -0.0059],
        ...,
        [-0.0326,  0.0789, -0.0282,  ...,  0.0675, -0.0544,  0.0092],
        [-0.0324, -0.0388,  0.0506,  ..., -0.0011, -0.0232,  0.0773],
        [ 0.0031, -0.0976,  0.0209,  ...,  0.0206, -0.0318, -0.0092]],
       requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-6.8313e-02,  6.7761e-02,  1.6668e-01,  1.2392e-01,  2.9475e-01,
        -1.0670e-01, -1.5829e-01,  2.3700e-01, -4.7921e-01, -2.6351e-01,
        -5.6057e-02, -4.0667e-01,  1.8796e-01, -8.7701e-02, -5.2525e-02,
         2.6205e-01, -6.7708e-02,  2.6149e-02, -3.9226e-01,  2.4256e-01,
         2.9250e-01,  4.2455e-02, -3.9571e-01, -3.4621e-02, -1.9131e-01,
         5.8470e-01,  1.8993e-01,  9.0211e-02,  3.6801e-01,  2.9665e-01,
        -2.1002e-01,  2.

In [19]:
# Save the pruned model to its own output directory.
import os
out_path = '../../../model/distilbert/distilbert-0331-TS-nli-0.1-10-pruing-global'

# Create the directory (and any missing parents); an existing one is fine.
os.makedirs(name=out_path, exist_ok=True)
model.save_pretrained(out_path)

In [21]:
# Parameter count after pruning, for comparison with the count printed earlier.
param_count_after_pruning = model.num_parameters()
param_count_after_pruning

172193283