In [1]:
#===================================================================================================
# Local pruning
# - Uses the torch.nn.utils.prune module to prune individual (local) weights and biases
#   of a BERT model.
#
# Reference article: https://huffon.github.io/2020/03/15/torch-pruning/
# Reference source : https://github.com/Huffon/nlp-various-tutorials/blob/master/pruning-bert.ipynb
#===================================================================================================

import torch
import torch.nn.utils.prune as prune

# Pruning requires torch 1.4.0 or newer; display the installed version to confirm.
torch.__version__

'1.10.1'

In [None]:
# Load the BERT model from a local checkpoint directory and print its module tree.
# NOTE(review): model_path is relative to the notebook's working directory — confirm it
# resolves on the machine running this notebook.
from transformers import BertModel
model_path = '../../../model/bert/bmc-fpt-bong_corpus_mecab-0428'
model = BertModel.from_pretrained(model_path)
print(model)

In [4]:
# Inspect the encoder layers of the model (a ModuleList of BertLayer modules).
model.encoder.layer

ModuleList(
  (0): BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=768, out_features=768, bias=True)
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=768, out_features=3072, bias=True)
    )
    (output): BertOutput(
      (dense): Linear(in_features=3072, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (1): BertLayer(
    (attention): BertAttention(
      (self)

In [5]:
# Look at only the self-attention sub-module of the first encoder layer.
model.encoder.layer[0].attention.self

BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [6]:
# Pruning is demonstrated on the `key` projection of the first layer's self-attention.
# Its named_parameters() contains `weight` and `bias`.
module = model.encoder.layer[0].attention.self.key

In [7]:
# Before pruning: the parameters are the plain `weight` and `bias` tensors.
list(module.named_parameters())

[('weight',
  Parameter containing:
  tensor([[ 8.0064e-03, -7.2079e-02, -4.5409e-02,  ...,  1.3022e-02,
           -4.1921e-02, -1.8508e-02],
          [ 4.7319e-02,  5.6065e-02,  7.8608e-03,  ..., -1.5596e-02,
            1.1952e-02,  1.3859e-02],
          [-1.3696e-02, -4.2909e-02,  4.8609e-03,  ..., -3.5625e-03,
           -9.7081e-02, -5.2629e-02],
          ...,
          [-2.3917e-02, -3.2999e-02,  8.4614e-03,  ...,  8.6428e-05,
            1.8967e-04,  1.8264e-03],
          [-9.8836e-02, -8.1632e-02, -2.3575e-02,  ..., -1.8363e-02,
           -5.1585e-02, -1.2072e-01],
          [-1.6689e-02, -4.9597e-02, -2.4316e-02,  ...,  4.4092e-02,
           -3.1870e-02,  3.8483e-02]], requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([ 1.0594e-03,  7.3360e-04,  5.4275e-03, -1.3735e-03,  2.9526e-03,
           5.4779e-03,  2.9630e-03,  6.7515e-03,  4.7061e-03,  3.2703e-03,
          -2.9103e-03,  1.1439e-03, -4.9434e-03,  4.8971e-03,  1.2137e-02,
          -5.1321e-04, -5

In [8]:
# Before pruning: named_buffers() is empty (no masks exist yet).
list(module.named_buffers())

[]

In [10]:
# Start pruning: randomly prune 30% of the entries of key's `weight`.
# Arguments: the module to prune; name = the parameter to prune ('weight', 'bias', ...);
# amount = a float in [0, 1] for a fraction of the entries, or an int for an absolute count.
# The mask is drawn at random, so seed torch's global RNG first: otherwise every
# Restart & Run All yields a different mask and therefore a different pruned model.
torch.manual_seed(42)
prune.random_unstructured(module, name='weight', amount=0.3)


Linear(in_features=768, out_features=768, bias=True)

In [11]:
# After pruning, `weight` is no longer listed as a parameter; it is replaced by `weight_orig`,
# which keeps the original (unpruned) values.
list(module.named_parameters())

[('bias',
  Parameter containing:
  tensor([ 1.0594e-03,  7.3360e-04,  5.4275e-03, -1.3735e-03,  2.9526e-03,
           5.4779e-03,  2.9630e-03,  6.7515e-03,  4.7061e-03,  3.2703e-03,
          -2.9103e-03,  1.1439e-03, -4.9434e-03,  4.8971e-03,  1.2137e-02,
          -5.1321e-04, -5.4592e-03, -9.2707e-04,  5.4551e-03, -4.9731e-03,
          -6.3760e-03, -2.2797e-03,  9.0375e-03,  6.7416e-03, -5.5720e-03,
          -8.4298e-03,  2.1314e-03, -7.0118e-03,  1.4222e-03, -3.8760e-04,
           4.2457e-04, -2.0165e-03,  4.2753e-03,  5.7197e-03, -7.3704e-03,
           5.8238e-03, -3.5565e-03, -4.7996e-03, -4.2372e-03, -5.0620e-03,
          -1.4990e-02, -1.2310e-03,  1.3331e-03, -4.7811e-03,  3.9264e-03,
          -3.6671e-03, -5.8450e-04,  4.4344e-03, -7.1720e-03,  2.5168e-03,
          -4.0643e-03,  4.9319e-03, -2.3885e-04, -5.4040e-04,  1.5722e-03,
           7.4370e-03,  5.1830e-03, -8.6837e-03,  4.1268e-03, -1.6794e-03,
          -2.8951e-03,  2.9707e-03,  7.9118e-03, -5.1855e-03,  2.5

In [12]:
# named_buffers() now contains the pruning mask, `weight_mask`.
list(module.named_buffers())

[('weight_mask',
  tensor([[1., 0., 1.,  ..., 0., 1., 0.],
          [1., 0., 1.,  ..., 1., 1., 0.],
          [0., 0., 1.,  ..., 1., 0., 1.],
          ...,
          [1., 1., 1.,  ..., 0., 0., 1.],
          [0., 1., 1.,  ..., 0., 1., 0.],
          [0., 1., 1.,  ..., 1., 1., 1.]]))]

In [13]:
# The module's `weight` is now computed from weight_orig and weight_mask
# (note grad_fn=<MulBackward0> in the displayed tensor).
# The result is stored as a plain attribute of the module, not as a parameter.
module.weight

tensor([[ 0.0080, -0.0000, -0.0454,  ...,  0.0000, -0.0419, -0.0000],
        [ 0.0473,  0.0000,  0.0079,  ..., -0.0156,  0.0120,  0.0000],
        [-0.0000, -0.0000,  0.0049,  ..., -0.0036, -0.0000, -0.0526],
        ...,
        [-0.0239, -0.0330,  0.0085,  ...,  0.0000,  0.0000,  0.0018],
        [-0.0000, -0.0816, -0.0236,  ..., -0.0000, -0.0516, -0.0000],
        [-0.0000, -0.0496, -0.0243,  ...,  0.0441, -0.0319,  0.0385]],
       grad_fn=<MulBackward0>)

In [15]:
# Size of the weight tensor.
module.weight.size()

torch.Size([768, 768])

In [16]:
# Number of weight entries that are exactly zero after pruning.
(module.weight == 0).sum()

tensor(176947)

In [17]:
# Pruning is applied before each forward pass through a hook stored in the module's
# _forward_pre_hooks attribute.
module._forward_pre_hooks

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured at 0x7fe2aedc5520>)])

In [18]:
# Prune key's `bias` as well.
# => l1_unstructured prunes the entries with the least influence as measured by the L1 norm.
#    `amount` is an int here, so it is an absolute count: 50 entries are pruned, not a fraction.
prune.l1_unstructured(module, name='bias', amount=50)

Linear(in_features=768, out_features=768, bias=True)

In [19]:
# Check that `bias_orig` is now among the parameters.
list(module.named_parameters())

[('weight_orig',
  Parameter containing:
  tensor([[ 8.0064e-03, -7.2079e-02, -4.5409e-02,  ...,  1.3022e-02,
           -4.1921e-02, -1.8508e-02],
          [ 4.7319e-02,  5.6065e-02,  7.8608e-03,  ..., -1.5596e-02,
            1.1952e-02,  1.3859e-02],
          [-1.3696e-02, -4.2909e-02,  4.8609e-03,  ..., -3.5625e-03,
           -9.7081e-02, -5.2629e-02],
          ...,
          [-2.3917e-02, -3.2999e-02,  8.4614e-03,  ...,  8.6428e-05,
            1.8967e-04,  1.8264e-03],
          [-9.8836e-02, -8.1632e-02, -2.3575e-02,  ..., -1.8363e-02,
           -5.1585e-02, -1.2072e-01],
          [-1.6689e-02, -4.9597e-02, -2.4316e-02,  ...,  4.4092e-02,
           -3.1870e-02,  3.8483e-02]], requires_grad=True)),
 ('bias_orig',
  Parameter containing:
  tensor([ 1.0594e-03,  7.3360e-04,  5.4275e-03, -1.3735e-03,  2.9526e-03,
           5.4779e-03,  2.9630e-03,  6.7515e-03,  4.7061e-03,  3.2703e-03,
          -2.9103e-03,  1.1439e-03, -4.9434e-03,  4.8971e-03,  1.2137e-02,
          -5.13

In [20]:
# Check that `bias_mask` has been added to the buffers.
list(module.named_buffers())

[('weight_mask',
  tensor([[1., 0., 1.,  ..., 0., 1., 0.],
          [1., 0., 1.,  ..., 1., 1., 0.],
          [0., 0., 1.,  ..., 1., 0., 1.],
          ...,
          [1., 1., 1.,  ..., 0., 0., 1.],
          [0., 1., 1.,  ..., 0., 1., 0.],
          [0., 1., 1.,  ..., 1., 1., 1.]])),
 ('bias_mask',
  tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.,
          0., 1., 1., 1., 0., 0., 1., 1., 

In [21]:
# Display the bias after the mask has been applied.
module.bias

tensor([ 0.0011,  0.0007,  0.0054, -0.0014,  0.0030,  0.0055,  0.0030,  0.0068,
         0.0047,  0.0033, -0.0029,  0.0011, -0.0049,  0.0049,  0.0121, -0.0005,
        -0.0055, -0.0009,  0.0055, -0.0050, -0.0064, -0.0023,  0.0090,  0.0067,
        -0.0056, -0.0084,  0.0021, -0.0070,  0.0014, -0.0000,  0.0000, -0.0020,
         0.0043,  0.0057, -0.0074,  0.0058, -0.0036, -0.0048, -0.0042, -0.0051,
        -0.0150, -0.0012,  0.0013, -0.0048,  0.0039, -0.0037, -0.0006,  0.0044,
        -0.0072,  0.0025, -0.0041,  0.0049, -0.0000, -0.0005,  0.0016,  0.0074,
         0.0052, -0.0087,  0.0041, -0.0017, -0.0029,  0.0030,  0.0079, -0.0052,
         0.0026,  0.0016,  0.0076, -0.0000, -0.0059,  0.0068,  0.0025,  0.0037,
         0.0029,  0.0036,  0.0079,  0.0031, -0.0087,  0.0016,  0.0000, -0.0030,
         0.0020, -0.0000,  0.0011, -0.0017,  0.0016,  0.0078,  0.0018,  0.0045,
         0.0010, -0.0035,  0.0024,  0.0057,  0.0042, -0.0055, -0.0057, -0.0044,
         0.0038,  0.0060,  0.0062,  0.00

In [22]:
# Size of the bias tensor.
module.bias.size()

torch.Size([768])

In [23]:
# Number of bias entries that are exactly zero.
(module.bias==0).sum()

tensor(50)

In [25]:
# Display _forward_pre_hooks again => there are now two hooks, one for weight and one for bias.
module._forward_pre_hooks

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured at 0x7fe2aedc5520>),
             (1, <torch.nn.utils.prune.L1Unstructured at 0x7fe2aedc5a60>)])

In [26]:
# Stacked pruning
# => Apply a second pruning method on top of the `weight` parameter that was pruned earlier.
# => ln_structured with n=2 uses the L2 norm along the given dim to find the least influential
#    slices of `weight` and prunes 30% of them.
#    NOTE(review): with dim=1 the displayed mask shows whole columns zeroed — confirm this is
#    the intended axis.

prune.ln_structured(module, name='weight', amount=0.3, n=2, dim=1)

Linear(in_features=768, out_features=768, bias=True)

In [27]:
# Because pruning was stacked on `weight`, its mask now contains more zeros than before.
# The mask is read by name rather than by position in named_buffers(), so this cell does not
# depend on the order in which the buffers were registered.
module.weight_mask

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 1.],
        ...,
        [1., 1., 0.,  ..., 0., 0., 1.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 1., 0., 1.]])

In [28]:
# The actual weight also has more zero entries now that the stacked pruning is applied.
module.weight

tensor([[ 0.0080, -0.0000, -0.0000,  ...,  0.0000, -0.0000, -0.0000],
        [ 0.0473,  0.0000,  0.0000,  ..., -0.0156,  0.0000,  0.0000],
        [-0.0000, -0.0000,  0.0000,  ..., -0.0036, -0.0000, -0.0526],
        ...,
        [-0.0239, -0.0330,  0.0000,  ...,  0.0000,  0.0000,  0.0018],
        [-0.0000, -0.0816, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [-0.0000, -0.0496, -0.0000,  ...,  0.0441, -0.0000,  0.0385]],
       grad_fn=<MulBackward0>)

In [29]:
# With stacked pruning, the hook attached to `weight` is a torch.nn.utils.prune.PruningContainer.
# Select the hook by its tensor name. If nothing matches, fail loudly instead of silently
# keeping whichever hook happened to be visited last; hooks that carry no `_tensor_name`
# attribute are skipped rather than raising AttributeError.
hook = next(
    (
        candidate
        for candidate in module._forward_pre_hooks.values()
        if getattr(candidate, '_tensor_name', None) == 'weight'
    ),
    None,
)
if hook is None:
    raise LookupError("no pruning hook is registered for 'weight' on this module")

hook

<torch.nn.utils.prune.PruningContainer at 0x7fe2aedc5700>

In [30]:
# Two stacked pruning methods are applied to `weight`; listing the container shows both.
list(hook)

[<torch.nn.utils.prune.RandomUnstructured at 0x7fe2aedc5520>,
 <torch.nn.utils.prune.LnStructured at 0x7fe2aedc50d0>]

In [31]:
# Applying pruning to the model
# While pruning is attached, state_dict() contains all of the *_mask and *_orig entries.
module.state_dict().keys()

odict_keys(['weight_orig', 'bias_orig', 'weight_mask', 'bias_mask'])

In [32]:
# Make the pruning permanent
# => To finalize pruning, weight_orig, weight_mask, bias_orig, bias_mask and the forward
#    pre-hooks have to be removed from the module.
# => prune.remove() does this for one parameter at a time; here it is applied to `weight`.

prune.remove(module, 'weight')
list(module.named_parameters())

[('bias_orig',
  Parameter containing:
  tensor([ 1.0594e-03,  7.3360e-04,  5.4275e-03, -1.3735e-03,  2.9526e-03,
           5.4779e-03,  2.9630e-03,  6.7515e-03,  4.7061e-03,  3.2703e-03,
          -2.9103e-03,  1.1439e-03, -4.9434e-03,  4.8971e-03,  1.2137e-02,
          -5.1321e-04, -5.4592e-03, -9.2707e-04,  5.4551e-03, -4.9731e-03,
          -6.3760e-03, -2.2797e-03,  9.0375e-03,  6.7416e-03, -5.5720e-03,
          -8.4298e-03,  2.1314e-03, -7.0118e-03,  1.4222e-03, -3.8760e-04,
           4.2457e-04, -2.0165e-03,  4.2753e-03,  5.7197e-03, -7.3704e-03,
           5.8238e-03, -3.5565e-03, -4.7996e-03, -4.2372e-03, -5.0620e-03,
          -1.4990e-02, -1.2310e-03,  1.3331e-03, -4.7811e-03,  3.9264e-03,
          -3.6671e-03, -5.8450e-04,  4.4344e-03, -7.1720e-03,  2.5168e-03,
          -4.0643e-03,  4.9319e-03, -2.3885e-04, -5.4040e-04,  1.5722e-03,
           7.4370e-03,  5.1830e-03, -8.6837e-03,  4.1268e-03, -1.6794e-03,
          -2.8951e-03,  2.9707e-03,  7.9118e-03, -5.1855e-03,

In [33]:
# Remove the pruning re-parametrization from `bias` as well; afterwards the parameters are
# the plain `weight` and `bias` again, with the pruned entries set to zero.
prune.remove(module, "bias")
list(module.named_parameters())

[('weight',
  Parameter containing:
  tensor([[ 0.0080, -0.0000, -0.0000,  ...,  0.0000, -0.0000, -0.0000],
          [ 0.0473,  0.0000,  0.0000,  ..., -0.0156,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000,  ..., -0.0036, -0.0000, -0.0526],
          ...,
          [-0.0239, -0.0330,  0.0000,  ...,  0.0000,  0.0000,  0.0018],
          [-0.0000, -0.0816, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0496, -0.0000,  ...,  0.0441, -0.0000,  0.0385]],
         requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([ 0.0011,  0.0007,  0.0054, -0.0014,  0.0030,  0.0055,  0.0030,  0.0068,
           0.0047,  0.0033, -0.0029,  0.0011, -0.0049,  0.0049,  0.0121, -0.0005,
          -0.0055, -0.0009,  0.0055, -0.0050, -0.0064, -0.0023,  0.0090,  0.0067,
          -0.0056, -0.0084,  0.0021, -0.0070,  0.0014, -0.0000,  0.0000, -0.0020,
           0.0043,  0.0057, -0.0074,  0.0058, -0.0036, -0.0048, -0.0042, -0.0051,
          -0.0150, -0.0012,  0.0013, -0.0

In [34]:
# Display named_buffers(): the masks are gone, so it is empty again.
list(module.named_buffers())

[]

In [35]:
# Display _forward_pre_hooks: the pruning hooks are gone, so it is empty again.
module._forward_pre_hooks

OrderedDict()

In [37]:
# Save the model, including the pruned key projection, to a new directory.
# NOTE(review): the pruned entries are kept as explicit zeros in dense tensors, so the saved
# checkpoint is presumably no smaller than the original — confirm if file size matters.
import os
out_path = '../../../model/bert/bmc-fpt-bong_corpus_mecab-0428-pruning1'
os.makedirs(out_path, exist_ok=True)
model.save_pretrained(out_path)