# DepGraph Pruning


In [None]:
pip install torch-pruning

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
import matplotlib.pyplot as plt
import random
import numpy as np
import torch_pruning as tp
from google.colab import drive
# Mount Google Drive so the checkpoints under /content/drive/MyDrive can be
# read and written by the cells below.
drive.mount('/content/drive')

# Seed PyTorch (CPU), Python's `random`, NumPy and the current CUDA device with
# one shared value so shuffling, augmentation and random init are repeatable.
SEED = 42
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed, torch.cuda.manual_seed):
    _seed_fn(SEED)

## Loading the model weights

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_cifar_resnet18(num_classes=10):
    """Return a torchvision ResNet-18 whose stem is adapted to CIFAR-sized images.

    The 7x7/stride-2 first convolution is replaced by a 3x3/stride-1 one, and the
    max-pool uses stride 1. The classifier head is resized to `num_classes`.
    The architecture must match the one the checkpoints were trained with.
    """
    net = resnet18()
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
    net.fc = nn.Linear(512, num_classes)
    return net


# map_location lets checkpoints that were saved from a GPU also load on a
# CPU-only runtime (and vice versa).
state_dict = torch.load('/content/drive/MyDrive/nat_cifar10_model.pth', map_location=device)
state_dict_adv = torch.load('/content/drive/MyDrive/adv_cifar10_model.pth', map_location=device)

# Naturally trained model and adversarially trained model, same architecture.
model = build_cifar_resnet18()
adv_model = build_cifar_resnet18()

model.load_state_dict(state_dict['model_state_dict'])
adv_model.load_state_dict(state_dict_adv['model_state_dict'])

model.to(device)
adv_model.to(device)

In [None]:
# Per-channel CIFAR-10 mean / std. This is NOT part of the dataset transforms:
# the datasets yield un-normalized ToTensor output, and normalization is applied
# separately (e.g. in get_example_inputs).
normalizer = transforms.Normalize(
    (0.49139968, 0.48215827, 0.44653124),
    (0.24703233, 0.24348505, 0.26158768),
)

# Training images: 4-pixel padded random 32x32 crop + random horizontal flip.
# Test images: tensor conversion only.
transform_train = transforms.Compose(
    [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor()]
)
transform_test = transforms.Compose([transforms.ToTensor()])

train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Batches of 64; only the training loader is shuffled.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

In [None]:
def get_example_inputs(data_loader, normalizer=normalizer):
    """Return the normalized images of one batch drawn from `data_loader`.

    The result serves as the dummy input used to trace the network (pruner
    construction and MAC/parameter counting), so the labels are discarded.

    Args:
        data_loader: iterable yielding ``(images, labels)`` pairs.
        normalizer: callable applied to the image batch. The default is the
            module-level ``normalizer`` as it existed when this function was
            defined (defaults are bound at definition time).

    Returns:
        ``normalizer(images)`` for the first batch the loader yields.
    """
    images, _ = next(iter(data_loader))
    return normalizer(images)

## Pruning the Standard Model

In [None]:
# Rank channel groups by their L2 group norm.
imp = tp.importance.GroupNormImportance(p=2)

# One normalized batch on the model's device, used only to trace the graph.
example_inputs = get_example_inputs(train_loader).to(device)

# Keep the 10-way classifier head intact so the network still outputs 10 logits.
ignored_layers = [
    layer
    for layer in model.modules()
    if isinstance(layer, nn.Linear) and layer.out_features == 10
]

pruner = tp.pruner.GroupNormPruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5,
    ignored_layers=ignored_layers,
)

# Cost before and after a single structural pruning step at ratio 0.5.
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)

## Fine-tuning the pruned standard model for 10 epochs

In [None]:
# Per-epoch checkpoint path (overwritten every epoch).
PATH = "/content/drive/MyDrive/model_nat_pruned.pth"

# `optimizer`, `lr_sched`, `train`, `test`, `test_adv` and `test_adv_loader` are
# defined elsewhere in the notebook. torch-pruning swaps pruned weights for new
# nn.Parameter objects, so an optimizer built before pruner.step() (or for
# another model) would silently leave `model` untrained. Fail fast instead.
_live_params = {id(p) for p in model.parameters()}
if any(id(p) not in _live_params for group in optimizer.param_groups for p in group['params']):
    raise RuntimeError(
        "optimizer does not hold the pruned model's parameters; re-create "
        "optimizer (and lr_sched) from model.parameters() after pruning"
    )
if getattr(lr_sched, 'optimizer', optimizer) is not optimizer:
    raise RuntimeError("lr_sched is not attached to optimizer; re-create it after the optimizer")

for epoch in tqdm_notebook(range(1, 11)):
    running_loss = train(model, device, train_loader, optimizer, epoch)
    test_acc = test(model, device, test_loader)
    test_acc_adv = test_adv(model, device, test_adv_loader)
    # The whole module is pickled because the pruned architecture no longer
    # matches a stock resnet18, so a bare state_dict could not be re-loaded into one.
    torch.save({
        'epoch': epoch,
        'model': model,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': running_loss,
    }, PATH)
    # NOTE(review): clean accuracy minus adversarial accuracy is reported as the
    # "attack success rate"; confirm this definition and that test/test_adv
    # return percentages.
    print(f'\nAttack success rate:{test_acc - test_acc_adv:.2f}%\n')
    lr_sched.step()

In [None]:
# Clear gradients before saving (with PyTorch >= 2.0 defaults they are set to
# None) so gradient tensors are not pickled along with the module.
model.zero_grad()
# Save the whole pruned module: its shapes no longer match a stock resnet18.
# NOTE: this overwrites the per-epoch checkpoint dict written to the same path
# by the fine-tuning loop.
torch.save(model, '/content/drive/MyDrive/model_nat_pruned.pth')
# Reload as a sanity check. A pickled nn.Module needs weights_only=False
# (PyTorch 2.6 changed the default to True). This unpickles arbitrary objects,
# so only load files you wrote yourself.
loaded_model = torch.load('/content/drive/MyDrive/model_nat_pruned.pth', weights_only=False)

## Pruning the Adversarially Robust Model


In [None]:
# Rank channel groups by their L2 group norm.
imp = tp.importance.GroupNormImportance(p=2)

# One normalized batch on the model's device, used only to trace the graph.
example_inputs = get_example_inputs(train_loader).to(device)

# Protect adv_model's OWN 10-way classifier head. The modules must come from
# the model being pruned. Layers of a different model are not in this pruner's
# graph, which would leave adv_model.fc free to lose class outputs.
ignored_layers = [
    layer
    for layer in adv_model.modules()
    if isinstance(layer, nn.Linear) and layer.out_features == 10
]

pruner = tp.pruner.GroupNormPruner(
    adv_model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5,
    ignored_layers=ignored_layers,
)

# Cost before and after a single structural pruning step at ratio 0.5.
base_macs, base_nparams = tp.utils.count_ops_and_params(adv_model, example_inputs)
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(adv_model, example_inputs)

## Fine-tuning the pruned robust model for 10 epochs

In [None]:
# Per-epoch checkpoint path (overwritten every epoch).
PATH = "/content/drive/MyDrive/model_adv_pruned.pth"

# `optimizer`, `lr_sched`, `train_adv`, `test`, `test_adv` and `test_adv_loader`
# are defined elsewhere in the notebook. The same `optimizer` name is also used
# to fine-tune `model`, and torch-pruning swaps pruned weights for new
# nn.Parameter objects. Unless the optimizer was re-created from adv_model
# after pruning, this loop would silently leave adv_model untrained. Fail fast.
_live_params = {id(p) for p in adv_model.parameters()}
if any(id(p) not in _live_params for group in optimizer.param_groups for p in group['params']):
    raise RuntimeError(
        "optimizer does not hold adv_model's pruned parameters; re-create "
        "optimizer (and lr_sched) from adv_model.parameters() after pruning"
    )
if getattr(lr_sched, 'optimizer', optimizer) is not optimizer:
    raise RuntimeError("lr_sched is not attached to optimizer; re-create it after the optimizer")

for epoch in tqdm_notebook(range(1, 11)):
    running_loss = train_adv(adv_model, device, train_loader, optimizer, epoch)
    test_acc = test(adv_model, device, test_loader)
    test_acc_adv = test_adv(adv_model, device, test_adv_loader)
    # The whole module is pickled because the pruned architecture no longer
    # matches a stock resnet18, so a bare state_dict could not be re-loaded into one.
    torch.save({
        'epoch': epoch,
        'model': adv_model,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': running_loss,
    }, PATH)
    # NOTE(review): clean accuracy minus adversarial accuracy is reported as the
    # "attack success rate"; confirm this definition and that test/test_adv
    # return percentages.
    print(f'\nAttack success rate:{test_acc - test_acc_adv:.2f}%\n')
    lr_sched.step()

In [None]:
# Clear gradients before saving (with PyTorch >= 2.0 defaults they are set to
# None) so gradient tensors are not pickled along with the module.
adv_model.zero_grad()
# Save the whole pruned module: its shapes no longer match a stock resnet18.
# NOTE: this overwrites the per-epoch checkpoint dict written to the same path
# by the fine-tuning loop.
torch.save(adv_model, '/content/drive/MyDrive/model_adv_pruned.pth')
# Reload as a sanity check. A pickled nn.Module needs weights_only=False
# (PyTorch 2.6 changed the default to True). This unpickles arbitrary objects,
# so only load files you wrote yourself.
loaded_model = torch.load('/content/drive/MyDrive/model_adv_pruned.pth', weights_only=False)