In [4]:
import torch
import torch.nn.utils.prune as prune

from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection

###  Target network: Visual Encoder
- Num of params : 88M

In [17]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP vision encoder (with projection head) from the pretrained checkpoint;
# the text encoder stays commented out below.
vision_model = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-base-patch32"
)
# text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")


Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPVisionModelWithProjection: ['text_model.encoder.layers.9.layer_norm1.bias', 'text_model.encoder.layers.7.layer_norm2.weight', 'text_model.encoder.layers.5.mlp.fc2.bias', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.8.layer_norm1.weight', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.1.mlp.fc1.bias', 'text_model.encoder.layers.2.layer_norm1.bias', 'text_model.encoder.layers.2.self_attn.q_proj.bias', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.2.self_attn.q_proj.weight', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.laye

In [18]:
# Total number of elements across all trainable parameters of the vision encoder (~88M).
trainable_sizes = [param.numel() for param in vision_model.parameters() if param.requires_grad]
sum(trainable_sizes)

87849216

In [19]:
# print(vision_model)

# CLIPVisionModelWithProjection(
#   (vision_model): CLIPVisionTransformer(
#     (embeddings): CLIPVisionEmbeddings(
#       (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
#       (position_embedding): Embedding(50, 768)
#     )
#     (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
#     (encoder): CLIPEncoder(
#       (layers): ModuleList(
#         (0-11): 12 x CLIPEncoderLayer(
#           (self_attn): CLIPAttention(
#             (k_proj): Linear(in_features=768, out_features=768, bias=True)
#             (v_proj): Linear(in_features=768, out_features=768, bias=True)
#             (q_proj): Linear(in_features=768, out_features=768, bias=True)
#             (out_proj): Linear(in_features=768, out_features=768, bias=True)
#           )
#           (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
#           (mlp): CLIPMLP(
#             (activation_fn): QuickGELUActivation()
#             (fc1): Linear(in_features=768, out_features=3072, bias=True)
#             (fc2): Linear(in_features=3072, out_features=768, bias=True)
#           )
#           (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
#         )
#       )
#     )
#     (post_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
#   )
#   (visual_projection): Linear(in_features=768, out_features=512, bias=False)
# )


### Pruning
- Unstructured pruning
    - 장점: 개별 가중치 단위로 제거하므로 성능 감소가 적음
    - 단점: 가중치를 0으로 마스킹할 뿐이라, sparse 연산 지원 없이는 모델 사이즈 감소·속도 향상 없음
- Structured pruning
    - 장점: 채널/헤드 단위로 제거하므로 실제 모델 사이즈 감소, 속도 향상
    - 단점: 성능 감소가 상대적으로 큼

-> Structured pruning으로 접근 필요

In [24]:
# Load the pretrained CLIP vision encoder again and rebind `vision_model`;
# presumably this gives the pruning experiment below unmodified weights — TODO confirm.
vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")

Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPVisionModelWithProjection: ['text_model.encoder.layers.9.layer_norm1.bias', 'text_model.encoder.layers.7.layer_norm2.weight', 'text_model.encoder.layers.5.mlp.fc2.bias', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.8.layer_norm1.weight', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.1.mlp.fc1.bias', 'text_model.encoder.layers.2.layer_norm1.bias', 'text_model.encoder.layers.2.self_attn.q_proj.bias', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.2.self_attn.q_proj.weight', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.laye

In [33]:
# Display the self-attention submodule of the first encoder layer.
vision_model.vision_model.encoder.layers[0].self_attn

CLIPAttention(
  (k_proj): Linear(in_features=768, out_features=768, bias=True)
  (v_proj): Linear(in_features=768, out_features=768, bias=True)
  (q_proj): Linear(in_features=768, out_features=768, bias=True)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
)

In [35]:
# Collect the attention projection weights (k/v/q/out) of EVERY encoder layer.
# The earlier version looped over range(12) but always indexed layers[0], so only the
# first layer was registered (12 times over) and layers 1-11 were never pruned.
attention_projections = ("k_proj", "v_proj", "q_proj", "out_proj")
parameters_to_prune = tuple(
    (getattr(layer.self_attn, proj_name), "weight")
    for layer in vision_model.vision_model.encoder.layers
    for proj_name in attention_projections
)

# Zero out the 20% smallest-magnitude weights, ranked globally across all collected tensors.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# The count below does not drop after pruning: torch.nn.utils.prune masks weights rather
# than removing them (the dense tensor is kept as the `weight_orig` parameter alongside a
# `weight_mask` buffer), so the number of parameter elements stays the same.
sum(p.numel() for p in vision_model.parameters() if p.requires_grad) # 88M -> 

87849216

### ResNet pruning

In [None]:
# NOTE(review): the variable is named ResNet_Encoder, but the checkpoint loaded here is
# "openai/clip-vit-base-patch32", the same ViT checkpoint used for `vision_model` —
# confirm which ResNet-based model was intended for this section.
ResNet_Encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")

### Model profiling

In [20]:
from torch.profiler import profile, record_function, ProfilerActivity

# Random batch of shape (5, 3, 224, 224) used as the dummy model input.
inputs = torch.randn(5, 3, 224, 224)

# ProfilerActivity.CPU 
# ProfilerActivity.CUDA
# Run the forward pass under no_grad so the profile reflects inference only, without
# autograd bookkeeping being recorded alongside the operators.
with torch.no_grad(), profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        vision_model(inputs)

# Top 10 operators ranked by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))


STAGE:2023-07-17 05:08:20 31944:31944 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
             model_inference         3.69%       7.565ms       100.00%     204.839ms     204.839ms             1  
                aten::linear         0.44%     892.000us        74.57%     152.756ms       2.093ms            73  
                 aten::addmm        64.99%     133.133ms        72.74%     149.002ms       2.069ms            72  
                 aten::copy_         8.75%      17.923ms         8.75%      17.923ms     123.607us           145  
            aten::layer_norm         0.05%     109.000us         5.18%      10.611ms     408.115us            26  
     aten::native_layer_norm         5.03%      10.313ms         5.13%      10.5

STAGE:2023-07-17 05:08:20 31944:31944 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-07-17 05:08:20 31944:31944 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [21]:
# Profile CPU memory for one forward pass. no_grad keeps autograd from retaining
# tensors for a backward pass, so the figures reflect inference-time allocations.
with torch.no_grad(), profile(activities=[ProfilerActivity.CPU],
        profile_memory=True, record_shapes=True) as prof:
    vision_model(inputs)

# Aggregate once, then show the top 10 operators by self memory and by total memory.
memory_averages = prof.key_averages()
print(memory_averages.table(sort_by="self_cpu_memory_usage", row_limit=10))
print(memory_averages.table(sort_by="cpu_memory_usage", row_limit=10))

STAGE:2023-07-17 05:08:21 31944:31944 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-07-17 05:08:22 31944:31944 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-07-17 05:08:22 31944:31944 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                 aten::addmm        66.09%     131.019ms        71.71%     142.174ms       1.975ms      79.10 Mb      79.10 Mb            72  
                   aten::mul         5.92%      11.730ms         6.19%      12.273ms     340.917us      79.10 Mb      79.10 Mb            36  
                 aten::empty         0.22%     439.000us         0.22%     439.000us       3.377us      53.52 Mb      53.52 Mb           130  
               aten::sigmoid         3.23%       6.395ms         3.23%       6.395ms     532.917us      35.16 Mb      35.16 Mb            12  