# Structured / Unstructured pruning

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.nn.utils import prune
from torch.utils.data import DataLoader
import copy
import time

transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=100, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torchvision.models.resnet18(pretrained=True)
model.fc = nn.Linear(512, 10)
model = model.to(device)

def printModelSize(model):
    numParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return numParams, float(numParams)*8/pow(2,30)

def apply_pruning(model, method, amount):
    pruned_model = copy.deepcopy(model)
    for name, module in pruned_model.named_modules():
        if hasattr(module, 'weight'):
            if method == 'unstructured':
                prune.l1_unstructured(module, name='weight', amount=amount)
            elif method == 'structured' and isinstance(module, nn.Conv2d):
                if isinstance(module, torch.nn.Conv2d):
                    prune.ln_structured(module, name='weight', amount=amount, n=2, dim=0)

    return pruned_model

def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    accuracy = 100. * correct / total
    return accuracy

def train_model(model, trainloader, testloader, num_epochs=5, lr=0.01):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        epoch_loss = running_loss / len(trainloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    end_time = time.time()
    times = end_time - start_time
    return model, times

ogparam, original_model_params = printModelSize(model)

print("Pruning 진행 하지 않은 모델 학습 시작")
baseline_model, baseline_time = train_model(model, trainloader, testloader, num_epochs=10)
baseline_accuracy = evaluate_model(baseline_model, testloader)

print(f"Baseline Model - Accuracy: {baseline_accuracy:.2f}%")
print(f"Baseline - Training Time: {baseline_time:.2f} seconds")

print("Unstructured Pruning 적용된 모델 학습 시작")
unstructured_model = apply_pruning(model, method='unstructured', amount=0.8)
unstructured_model, unstructured_time = train_model(unstructured_model, trainloader, testloader, num_epochs=10)
unstructured_accuracy = evaluate_model(unstructured_model, testloader)

print(f"Unstructured Pruned Model - Accuracy: {unstructured_accuracy:.2f}%")
print(f"Unstructured Pruning - Training Time: {unstructured_time:.2f} seconds")

for name, module in unstructured_model.named_modules():
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
        prune.remove(module, 'weight')

unparam, unstructured_model_params = printModelSize(unstructured_model)

print("Structured Pruning 적용된 모델 학습 시작")
structured_model = apply_pruning(model, method='structured', amount=0.8)
structured_model, structured_time = train_model(structured_model, trainloader, testloader, num_epochs=10)
structured_accuracy = evaluate_model(structured_model, testloader)

print(f"Structured Pruned Model - Accuracy: {structured_accuracy:.2f}%")
print(f"Structured Pruning - Training Time: {structured_time:.2f} seconds")

for name, module in structured_model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.remove(module, 'weight')

stparam, structured_model_params = printModelSize(structured_model)


print(f"{'Model Type':<20}{'Accuracy (%)':<15}{'Training Time (s)':<20}{'Parameters (GB)'}")
print("-" * 70)
print(f"{'Baseline Model':<20}{baseline_accuracy:<15.2f}{baseline_time:<20.2f}{original_model_params:<.3f}")
print(f"{'Unstructured Pruning':<20}{unstructured_accuracy:<15.2f}{unstructured_time:<20.2f}{unstructured_model_params:<.3f}")
print(f"{'Structured Pruning':<20}{structured_accuracy:<15.2f}{structured_time:<20.2f}{structured_model_params:<.3f}")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:08<00:00, 20.8MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 155MB/s]


Pruning 진행 하지 않은 모델 학습 시작
Epoch [1/10], Loss: 1.0442
Epoch [2/10], Loss: 0.7120
Epoch [3/10], Loss: 0.6225
Epoch [4/10], Loss: 0.5639
Epoch [5/10], Loss: 0.5309
Epoch [6/10], Loss: 0.4978
Epoch [7/10], Loss: 0.4803
Epoch [8/10], Loss: 0.4485
Epoch [9/10], Loss: 0.4329
Epoch [10/10], Loss: 0.4122
Baseline Model - Accuracy: 79.57%
Baseline - Training Time: 351.34 seconds
Unstructured Pruning 적용된 모델 학습 시작
Epoch [1/10], Loss: 1.8199
Epoch [2/10], Loss: 1.4524
Epoch [3/10], Loss: 1.2975
Epoch [4/10], Loss: 1.2193
Epoch [5/10], Loss: 1.1708
Epoch [6/10], Loss: 1.1255
Epoch [7/10], Loss: 1.0907
Epoch [8/10], Loss: 1.0675
Epoch [9/10], Loss: 1.0341
Epoch [10/10], Loss: 1.0238
Unstructured Pruned Model - Accuracy: 63.29%
Unstructured Pruning - Training Time: 368.66 seconds
Structured Pruning 적용된 모델 학습 시작
Epoch [1/10], Loss: 1.5093
Epoch [2/10], Loss: 1.1851
Epoch [3/10], Loss: 1.0723
Epoch [4/10], Loss: 1.0091
Epoch [5/10], Loss: 0.9525
Epoch [6/10], Loss: 0.9216
Epoch [7/10], Loss: 0.8908
Epoc

### 여기부턴 mask가 실제로 어떻게 씌워지는지 시각화를 해보았습니다.

pruning전 모델 구조 확인

In [4]:
# 2. ResNet-18 모델 준비
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torchvision.models.resnet18(pretrained=True)
model.fc = nn.Linear(512, 10)  # CIFAR-10에 맞게 출력 수정
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 197MB/s]


In [5]:
module = model.conv1

In [6]:
# prune을 적용할 모듈에 대해
module = model.conv1  # 예시로 모델의 첫 번째 conv 계층

if hasattr(module, 'weight'):
    prune.l1_unstructured(module, name='weight', amount=0.5)

if hasattr(module, 'bias') and module.bias is not None:
    prune.l1_unstructured(module, name='bias', amount=0.5)

In [7]:
module.weight

tensor([[[[-0.0000, -0.0000, -0.0000,  ...,  0.0566,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.1099,  ..., -0.2712, -0.1291,  0.0000],
          [-0.0000,  0.0591,  0.2955,  ...,  0.5197,  0.2563,  0.0636],
          ...,
          [-0.0000,  0.0000,  0.0726,  ..., -0.3328, -0.4206, -0.2578],
          [ 0.0000,  0.0410,  0.0628,  ...,  0.4138,  0.3936,  0.1661],
          [-0.0000, -0.0000, -0.0000,  ..., -0.1507, -0.0822, -0.0000]],

         [[-0.0000, -0.0000, -0.0000,  ...,  0.0000,  0.0000, -0.0000],
          [ 0.0457,  0.0000, -0.1045,  ..., -0.3125, -0.1605, -0.0000],
          [-0.0000,  0.0984,  0.4021,  ...,  0.7079,  0.3689,  0.1246],
          ...,
          [-0.0559, -0.0000,  0.0000,  ..., -0.4618, -0.5708, -0.3655],
          [ 0.0000,  0.0556,  0.0997,  ...,  0.5464,  0.4828,  0.1987],
          [ 0.0000,  0.0000, -0.0000,  ..., -0.1482, -0.0772,  0.0000]],

         [[-0.0000, -0.0000,  0.0000,  ...,  0.0892,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -

structured pruning

In [None]:
module = model.conv1

if hasattr(module, 'weight'):
  prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

In [None]:
print(module.weight)

tensor([[[[-1.0419e-02, -6.1356e-03, -1.8098e-03,  ...,  5.6615e-02,
            1.7083e-02, -1.2694e-02],
          [ 1.1083e-02,  9.5276e-03, -1.0993e-01,  ..., -2.7124e-01,
           -1.2907e-01,  3.7424e-03],
          [-6.9434e-03,  5.9089e-02,  2.9548e-01,  ...,  5.1972e-01,
            2.5632e-01,  6.3573e-02],
          ...,
          [-2.7535e-02,  1.6045e-02,  7.2595e-02,  ..., -3.3285e-01,
           -4.2058e-01, -2.5781e-01],
          [ 3.0613e-02,  4.0960e-02,  6.2850e-02,  ...,  4.1384e-01,
            3.9359e-01,  1.6606e-01],
          [-1.3736e-02, -3.6746e-03, -2.4084e-02,  ..., -1.5070e-01,
           -8.2230e-02, -5.7828e-03]],

         [[-1.1397e-02, -2.6619e-02, -3.4641e-02,  ...,  3.2521e-02,
            6.6221e-04, -2.5743e-02],
          [ 4.5687e-02,  3.3603e-02, -1.0453e-01,  ..., -3.1253e-01,
           -1.6051e-01, -1.2826e-03],
          [-8.3730e-04,  9.8420e-02,  4.0210e-01,  ...,  7.0789e-01,
            3.6887e-01,  1.2455e-01],
          ...,
     