## 掛載雲端硬碟

In [45]:
# from google.colab import drive
# drive.mount('/content/drive')

## 更改檔案所在路徑


In [1]:
# Change to your own folder !!!
%cd /home/twccjq88/2025EAI_Project/EAI_Lab2

/home/twccjq88/2025EAI_Project/EAI_Lab2


## 載入函式庫


In [2]:
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as np

from models.resnet import ResNet50

## 超參數設定

In [3]:
DATASET = 'cifar10'
TEST_BATCH_SIZE = 1000
CUDA = True
PRUNE_PERCENT = 0.90 # Change your prune ratio!
WEIGHT_PATH = '/home/twccjq88/2025EAI_Project/EAI_Lab2/checkpoints/model_lambda_1e-4.pth'
PRUNE_PATH = '/home/twccjq88/2025EAI_Project/EAI_Lab2/checkpoints/model_prune.pth'

## 載入模型

In [4]:
from collections import OrderedDict

CUDA = CUDA and torch.cuda.is_available()

model = ResNet50(num_classes=10)
if CUDA:
    model.cuda()

if WEIGHT_PATH:
    if os.path.isfile(WEIGHT_PATH):
        checkpoint = torch.load(WEIGHT_PATH)
        state_dict = checkpoint['state_dict']

        cleaned_state_dict = OrderedDict()
        for k, v in state_dict.items():
            cleaned_key = k.replace('module.', '', 1) if k.startswith('module.') else k
            cleaned_state_dict[cleaned_key] = v

        model.load_state_dict(cleaned_state_dict)
        best_acc = checkpoint.get('best_test_acc', None)
        epoch = checkpoint.get('epoch', 'N/A')
        print(f'LOADING CHECKPOINT {WEIGHT_PATH} @EPOCH={epoch}, BEST_ACC={best_acc}')
    else:
        print("NO CHECKPOINT FOUND")

print(model)

  checkpoint = torch.load(WEIGHT_PATH)


LOADING CHECKPOINT /home/twccjq88/2025EAI_Project/EAI_Lab2/checkpoints/model_lambda_1e-4.pth @EPOCH=39, BEST_ACC=0.9115
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     

## Run Pruning
#### Collect and sort the absolute scale factors from every BatchNorm layer
#### Use the configured PRUNE_PERCENT to pick the threshold

In [5]:
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        # The downsample shortcut BatchNorm uses affine=False, so skip it because it has no weights
        total += m.weight.data.shape[0]

bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size

y, i = torch.sort(bn)

threshold_index = min(total - 1, int(total * PRUNE_PERCENT))
threshold = y[threshold_index]


## Build CONFIG From BatchNorm Layers
#### 1. Copy each BatchNorm scale factor (γ)
#### 2. Create a mask that keeps values above the threshold and drops the rest
#### 3. Sum the mask to obtain the remaining output channels per layer
#### 4. Use these values to build the pruned model CONFIG

In [6]:
pruned = 0
cfg = []  #用來建立剪枝網路的CONFIG
cfg_mask = [] #用來幫助剪枝的遮罩

In [7]:
# ===== 第一步：建立初始 mask 和 cfg =====
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        weight_copy = m.weight.data.clone()
        mask = weight_copy.abs().gt(threshold).float().to(weight_copy.device)
        
        if torch.sum(mask) == 0:
            preserve = min(3, mask.numel())
            _, idx = torch.topk(weight_copy.abs(), k=preserve, largest=True, sorted=False)
            mask[idx] = 1.0
        
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.detach().cpu().clone())


# ===== 第二步：固定所有 BN3 到原始維度 =====
original_output_dims = [256, 512, 1024, 2048]
layers = [3, 4, 6, 3]
cfg_idx = 1

for layer_num, num_blocks in enumerate(layers):
    target_dim = original_output_dims[layer_num]
    
    for b in range(num_blocks):
        bn3_idx = cfg_idx + b * 3 + 2
        cfg[bn3_idx] = target_dim
        cfg_mask[bn3_idx] = torch.ones(target_dim)
    
    cfg_idx += num_blocks * 3


# ===== 第三步：計算結果 =====
print("="*70)
print(f'cfg: {cfg}')
print("="*70)

# 建立模型並計算參數
newmodel = ResNet50(num_classes=10, cfg=cfg)
if CUDA:
    newmodel.cuda()

pruned_params = sum(p.numel() for p in newmodel.parameters())
original_params = 23_513_162

print(f"Pruned parameters: {pruned_params:,} ({pruned_params/1e6:.2f}M)")
print(f"Original parameters: {original_params:,} (23.51M)")
print(f"Reduction: {(1 - pruned_params/original_params)*100:.2f}%")
print("="*70)

cfg: [64, 64, 64, 256, 64, 64, 256, 64, 64, 256, 122, 118, 512, 118, 123, 512, 102, 107, 512, 54, 37, 512, 27, 31, 1024, 23, 62, 1024, 3, 3, 1024, 3, 1, 1024, 3, 3, 1024, 3, 3, 1024, 1, 3, 2048, 3, 3, 2048, 3, 3, 2048]
Pruned parameters: 3,973,494 (3.97M)
Original parameters: 23,513,162 (23.51M)
Reduction: 83.10%


## 建立剪枝模型

In [8]:
newmodel = ResNet50(num_classes=10, cfg=cfg)
newmodel.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [9]:
# 檢查參數數量
pruned_params = sum(p.numel() for p in newmodel.parameters())
original_params = 23_513_162

print(f"Pruned parameters: {pruned_params:,} ({pruned_params/1e6:.2f}M)")
print(f"Original parameters: {original_params:,} (23.51M)")
print("="*70)

Pruned parameters: 3,973,494 (3.97M)
Original parameters: 23,513,162 (23.51M)


### 將原本的模型權重複製到剪枝的模型
#### 根據不同層決定要複製什麼權重
###### Batch Normalization Layer
1.   scale factor
2.   bias
3.   running mean
4.   running variance

###### Convolutional Layer
1.   weight

###### Linear Layer
1.   weight
2.   bias




In [10]:
old_modules = list(model.modules())
new_modules = list(newmodel.modules())

layer_id_in_cfg = 0
start_mask = torch.ones(3) #3為input channel(R,G,B)
end_mask = cfg_mask[layer_id_in_cfg]
bn_count = 0
block_input_mask = torch.ones(3)  # 用於記錄每個 block 開始時的輸入通道遮罩（用於 downsample 分支）
block_output_mask = None  # 用於記錄每個 block 結束時的輸出通道遮罩（BN3 的遮罩，用於 downsample 輸出）

for layer_id in range(len(old_modules)):

    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]

    if isinstance(m0, nn.BatchNorm2d):
        # downsample shortcut 的 BN 設為 affine=False，沒有 weight/bias
        # downsample BN 在 BN3 之後被處理，所以應該用 block_output_mask (BN3 的遮罩)
        if m0.weight is None:
            if block_output_mask is not None:
                idx = torch.nonzero(block_output_mask).squeeze().long().cpu()
                if idx.ndim == 0:
                    idx = idx.unsqueeze(0)
                m1.running_mean = m0.running_mean.cpu().index_select(0, idx).to(m1.running_mean.device).clone()
                m1.running_var = m0.running_var.cpu().index_select(0, idx).to(m1.running_var.device).clone()
            continue

        #### 找出遮罩中非零元素的index ####
        ################################################
        #          請填空          #
        ################################################
        current_mask = cfg_mask[layer_id_in_cfg]
        
        # 在每個 block 的第一個 BN 時（bn_count%3==1），記錄該 block 的輸入通道遮罩
        # 在每個 block 的第三個 BN 時（bn_count%3==0），保存該 block 的輸出遮罩
        # 注意: bn_count 在這裡尚未處理當前 BN,範圍是 [0, 1, 2, 3, 4, 5, ...]
        # bn_count=0: 即將處理 Conv1.BN (0%3=0, 但 layer_id_in_cfg==0 是特例)
        # bn_count=1: 即將處理 Block0.BN1 → 1%3=1 (BN1) ✓ 記錄輸入
        # bn_count=2: 即將處理 Block0.BN2 → 2%3=2 (BN2)
        # bn_count=3: 即將處理 Block0.BN3 → 3%3=0 (BN3) ✓ 記錄輸出
        if bn_count % 3 == 1:  # bn_count=1,4,7,... (block 的第一個 BN)
            if layer_id_in_cfg == 0:
                # 第一個 BN（conv1 後面）使用其自己的輸出作為後續 block 的輸入
                block_input_mask = current_mask.clone()
            else:
                # 後續 block 的輸入是上一個 BN 的輸出
                block_input_mask = start_mask.clone()
        
        # 在 BN3 時（block 的第 3 個 BN），保存該 block 的輸出遮罩
        # 供後續 downsample 層使用
        if bn_count % 3 == 0 and bn_count > 0:  # bn_count=3,6,9,... (block 的第三個 BN，即 BN3)
            block_output_mask = current_mask.clone()
        
        bn_count += 1
        
        idx = torch.nonzero(current_mask).squeeze().long()
        if idx.ndim == 0:
            idx = idx.unsqueeze(0)
        idx = idx.cpu()

        #### 複製weight, bias, running mean,and running variance ####
        ################################################
        #          請填空          #
        ################################################
        m1.weight.data = m0.weight.data.cpu().index_select(0, idx).to(m1.weight.device).clone()
        m1.bias.data = m0.bias.data.cpu().index_select(0, idx).to(m1.bias.device).clone()
        m1.running_mean = m0.running_mean.cpu().index_select(0, idx).to(m1.running_mean.device).clone()
        m1.running_var = m0.running_var.cpu().index_select(0, idx).to(m1.running_var.device).clone()

        layer_id_in_cfg += 1
        start_mask = end_mask.clone()

        #最後一層連接層不做修改
        if layer_id_in_cfg < len(cfg_mask):
            end_mask = cfg_mask[layer_id_in_cfg]
        else:
            end_mask = torch.ones_like(start_mask)


    elif isinstance(m0, nn.Conv2d):
        if isinstance(old_modules[layer_id + 1], nn.BatchNorm2d) and old_modules[layer_id + 1].weight is not None:
            idx0 = torch.nonzero(start_mask).squeeze().long().cpu()
            idx1 = torch.nonzero(end_mask).squeeze().long().cpu()

            #### 複製weight ####
            ################################################
            #          請填空          #
            ################################################
            if idx0.ndim == 0:
                idx0 = idx0.unsqueeze(0)
            if idx1.ndim == 0:
                idx1 = idx1.unsqueeze(0)
            w = m0.weight.data.cpu().index_select(1, idx0)
            w = w.index_select(0, idx1)
            m1.weight.data = w.to(m1.weight.device).clone()



        # downsample 層也需根據剪枝結果調整輸入／輸出通道
        # downsample Conv 在 BN3 之後被處理，所以應該用 block_output_mask (BN3 的遮罩) 作為輸出通道
        # 輸入應該是 block 開始時的通道數 (block_input_mask)
        # 輸出應該是 block 結束時的通道數 (block_output_mask，即 BN3 的遮罩)
        else:
            input_idx = torch.nonzero(block_input_mask).squeeze().long()
            output_idx = torch.nonzero(block_output_mask).squeeze().long()
            if input_idx.ndim == 0:
                input_idx = input_idx.unsqueeze(0)
            if output_idx.ndim == 0:
                output_idx = output_idx.unsqueeze(0)
            input_idx = input_idx.cpu()
            output_idx = output_idx.cpu()
            w = m0.weight.data.cpu().index_select(1, input_idx)
            w = w.index_select(0, output_idx)
            m1.weight.data = w.to(m1.weight.device).clone()


    elif isinstance(m0, nn.Linear):

        idx0 = torch.nonzero(start_mask).squeeze().long().cpu()
        if idx0.ndim == 0:
            idx0 = idx0.unsqueeze(0)

        #### 複製weight ####
        ################################################
        #          請填空          #
        ################################################
        w = m0.weight.data.cpu().index_select(1, idx0)
        m1.weight.data = w.to(m1.weight.device).clone()


        #### 複製bias ####
        m1.bias.data = m0.bias.data.clone()


## 測試函數




In [11]:
def test(model):
    kwargs = {'num_workers': 1, 'pin_memory': True} if CUDA else {}
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size=TEST_BATCH_SIZE, shuffle=True, **kwargs)

    model.eval()
    correct = 0
    with torch.no_grad():
      for data, target in test_loader:
          if CUDA:
              data, target = data.cuda(), target.cuda()
          data, target = Variable(data), Variable(target)
          output = model(data)
          pred = output.data.max(1, keepdim=True)[1]
          correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

## 儲存模型並印出結果，以及剪枝後的test acc


In [12]:
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, PRUNE_PATH)

print(newmodel)
pruned_param_count = sum(p.numel() for p in newmodel.parameters())
print(f"Pruned model parameter count: {pruned_param_count:,}")
model = newmodel.cuda()
test(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

tensor(0.3215)

## ==========================================
## PRUNE_PERCENT = 0.5 版本
## ==========================================

In [13]:
PRUNE_PERCENT_50 = 0.5  # 50% prune ratio
PRUNE_PATH_50 = '/home/twccjq88/2025EAI_Project/EAI_Lab2/checkpoints/model_prune_50.pth'

In [None]:
from collections import OrderedDict

# 重新載入原始模型
model_50 = ResNet50(num_classes=10)
if CUDA:
    model_50.cuda()

if WEIGHT_PATH:
    if os.path.isfile(WEIGHT_PATH):
        checkpoint = torch.load(WEIGHT_PATH)
        state_dict = checkpoint['state_dict']

        cleaned_state_dict = OrderedDict()
        for k, v in state_dict.items():
            cleaned_key = k.replace('module.', '', 1) if k.startswith('module.') else k
            cleaned_state_dict[cleaned_key] = v

        model_50.load_state_dict(cleaned_state_dict)
        print(f'LOADING CHECKPOINT FOR PRUNE_50: {WEIGHT_PATH}')
    else:
        print("NO CHECKPOINT FOUND")

In [None]:
total_50 = 0
for m in model_50.modules():
    if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        total_50 += m.weight.data.shape[0]

bn_50 = torch.zeros(total_50)
index_50 = 0
for m in model_50.modules():
    if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        size = m.weight.data.shape[0]
        bn_50[index_50:(index_50+size)] = m.weight.data.abs().clone()
        index_50 += size

y_50, i_50 = torch.sort(bn_50)

threshold_index_50 = min(total_50 - 1, int(total_50 * PRUNE_PERCENT_50))
threshold_50 = y_50[threshold_index_50]
print(f'Threshold for PRUNE_PERCENT=0.5: {threshold_50}')

In [None]:
pruned_50 = 0
cfg_50 = []  #用來建立剪枝網路的CONFIG
cfg_mask_50 = [] #用來幫助剪枝的遮罩

In [None]:
# ===== 第一步：建立初始 mask 和 cfg (PRUNE_PERCENT = 0.5) =====
for k, m in enumerate(model_50.modules()):
    if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        weight_copy = m.weight.data.clone()
        mask = weight_copy.abs().gt(threshold_50).float().to(weight_copy.device)
        
        if torch.sum(mask) == 0:
            preserve = min(3, mask.numel())
            _, idx = torch.topk(weight_copy.abs(), k=preserve, largest=True, sorted=False)
            mask[idx] = 1.0
        
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        pruned_50 = pruned_50 + mask.shape[0] - torch.sum(mask)
        cfg_50.append(int(torch.sum(mask)))
        cfg_mask_50.append(mask.detach().cpu().clone())


# ===== 第二步：固定所有 BN3 到原始維度 =====
original_output_dims = [256, 512, 1024, 2048]
layers = [3, 4, 6, 3]
cfg_idx = 1

for layer_num, num_blocks in enumerate(layers):
    target_dim = original_output_dims[layer_num]
    
    for b in range(num_blocks):
        bn3_idx = cfg_idx + b * 3 + 2
        cfg_50[bn3_idx] = target_dim
        cfg_mask_50[bn3_idx] = torch.ones(target_dim)
    
    cfg_idx += num_blocks * 3


# ===== 第三步：計算結果 =====
print("="*70)
print(f'cfg_50: {cfg_50}')
print("="*70)

# 建立模型並計算參數
newmodel_50 = ResNet50(num_classes=10, cfg=cfg_50)
if CUDA:
    newmodel_50.cuda()

pruned_params_50 = sum(p.numel() for p in newmodel_50.parameters())
original_params = 23_513_162

print(f"Pruned parameters (50%): {pruned_params_50:,} ({pruned_params_50/1e6:.2f}M)")
print(f"Original parameters: {original_params:,} (23.51M)")
print(f"Reduction: {(1 - pruned_params_50/original_params)*100:.2f}%")
print("="*70)

In [None]:
newmodel_50 = ResNet50(num_classes=10, cfg=cfg_50)
newmodel_50.cuda()

In [None]:
# 檢查參數數量
pruned_params_50 = sum(p.numel() for p in newmodel_50.parameters())
original_params = 23_513_162

print(f"Pruned parameters (50%): {pruned_params_50:,} ({pruned_params_50/1e6:.2f}M)")
print(f"Original parameters: {original_params:,} (23.51M)")
print("="*70)

In [None]:
old_modules_50 = list(model_50.modules())
new_modules_50 = list(newmodel_50.modules())

layer_id_in_cfg_50 = 0
start_mask_50 = torch.ones(3) #3為input channel(R,G,B)
end_mask_50 = cfg_mask_50[layer_id_in_cfg_50]
bn_count_50 = 0
block_input_mask_50 = torch.ones(3)
block_output_mask_50 = None

for layer_id in range(len(old_modules_50)):
    m0 = old_modules_50[layer_id]
    m1 = new_modules_50[layer_id]

    if isinstance(m0, nn.BatchNorm2d):
        if m0.weight is None:
            if block_output_mask_50 is not None:
                idx = torch.nonzero(block_output_mask_50).squeeze().long().cpu()
                if idx.ndim == 0:
                    idx = idx.unsqueeze(0)
                m1.running_mean = m0.running_mean.cpu().index_select(0, idx).to(m1.running_mean.device).clone()
                m1.running_var = m0.running_var.cpu().index_select(0, idx).to(m1.running_var.device).clone()
            continue

        current_mask_50 = cfg_mask_50[layer_id_in_cfg_50]
        
        if bn_count_50 % 3 == 1:
            if layer_id_in_cfg_50 == 0:
                block_input_mask_50 = current_mask_50.clone()
            else:
                block_input_mask_50 = start_mask_50.clone()
        
        if bn_count_50 % 3 == 0 and bn_count_50 > 0:
            block_output_mask_50 = current_mask_50.clone()
        
        bn_count_50 += 1
        
        idx = torch.nonzero(current_mask_50).squeeze().long()
        if idx.ndim == 0:
            idx = idx.unsqueeze(0)
        idx = idx.cpu()

        m1.weight.data = m0.weight.data.cpu().index_select(0, idx).to(m1.weight.device).clone()
        m1.bias.data = m0.bias.data.cpu().index_select(0, idx).to(m1.bias.device).clone()
        m1.running_mean = m0.running_mean.cpu().index_select(0, idx).to(m1.running_mean.device).clone()
        m1.running_var = m0.running_var.cpu().index_select(0, idx).to(m1.running_var.device).clone()

        layer_id_in_cfg_50 += 1
        start_mask_50 = end_mask_50.clone()

        if layer_id_in_cfg_50 < len(cfg_mask_50):
            end_mask_50 = cfg_mask_50[layer_id_in_cfg_50]
        else:
            end_mask_50 = torch.ones_like(start_mask_50)

    elif isinstance(m0, nn.Conv2d):
        if isinstance(old_modules_50[layer_id + 1], nn.BatchNorm2d) and old_modules_50[layer_id + 1].weight is not None:
            idx0 = torch.nonzero(start_mask_50).squeeze().long().cpu()
            idx1 = torch.nonzero(end_mask_50).squeeze().long().cpu()

            if idx0.ndim == 0:
                idx0 = idx0.unsqueeze(0)
            if idx1.ndim == 0:
                idx1 = idx1.unsqueeze(0)
            w = m0.weight.data.cpu().index_select(1, idx0)
            w = w.index_select(0, idx1)
            m1.weight.data = w.to(m1.weight.device).clone()
        else:
            input_idx = torch.nonzero(block_input_mask_50).squeeze().long()
            output_idx = torch.nonzero(block_output_mask_50).squeeze().long()
            if input_idx.ndim == 0:
                input_idx = input_idx.unsqueeze(0)
            if output_idx.ndim == 0:
                output_idx = output_idx.unsqueeze(0)
            input_idx = input_idx.cpu()
            output_idx = output_idx.cpu()
            w = m0.weight.data.cpu().index_select(1, input_idx)
            w = w.index_select(0, output_idx)
            m1.weight.data = w.to(m1.weight.device).clone()

    elif isinstance(m0, nn.Linear):
        idx0 = torch.nonzero(start_mask_50).squeeze().long().cpu()
        if idx0.ndim == 0:
            idx0 = idx0.unsqueeze(0)
        w = m0.weight.data.cpu().index_select(1, idx0)
        m1.weight.data = w.to(m1.weight.device).clone()
        m1.bias.data = m0.bias.data.clone()

In [None]:
torch.save({'cfg': cfg_50, 'state_dict': newmodel_50.state_dict()}, PRUNE_PATH_50)

print(newmodel_50)
pruned_param_count_50 = sum(p.numel() for p in newmodel_50.parameters())
print(f"Pruned model parameter count (50%): {pruned_param_count_50:,}")
model_50_final = newmodel_50.cuda()
test(model_50_final)