# TRANSTAILOR METHOD

## Import

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import itertools
import pickle

## Load model and dataset
* Model: VGG16
* Dataset: CIFAR10

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE: " + str(device))

# Load the VGG16 model with its ImageNet pre-trained weights (the source model).
model = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
batch_size = 64

# Resize every image to 224x224 and normalize with the standard ImageNet mean/std,
# matching the preprocessing used for the pre-trained weights.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Worker processes and pinned memory only pay off when batches are copied to a GPU.
# Compare `device.type` (a str): a `torch.device` never compares equal to a plain
# string, so the previous `device == 'cuda'` check was always False and these
# options were silently never applied.
kwargs = {'num_workers': 12, 'pin_memory': True} if device.type == 'cuda' else {}

# Load the CIFAR10 train_dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, **kwargs)

# Load test_dataset
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False, **kwargs)

# Replace the last layer of the model with a new layer that matches the number of classes in CIFAR10
num_classes = 10
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes)

model = model.to(device)

DEVICE: cuda


Files already downloaded and verified


In [3]:
num_epochs = 10
# Report the model size; right after loading, every weight is trainable.
model_params = list(model.parameters())
print("total", sum(p.numel() for p in model_params))
print("trainable", sum(p.numel() for p in model_params if p.requires_grad))

total 134301514
trainable 134301514


## Target-aware pruning
Fine-tune the model on the target data (CIFAR10).

### Load the model from previous training

In [4]:
# Load the fine-tuned parameters written by the "Save model for later use" cell.
# map_location remaps tensors that were saved from a GPU session onto the current
# device, so the checkpoint also loads on a CPU-only machine.
model.load_state_dict(torch.load('model_parameters_10_epochs.pt', map_location=device))

<All keys matched successfully>

### If no checkpoint is found, fine-tune
- During development, the model is fine-tuned for 10 epochs.
- In the paper, the authors fine-tune until 10% of the pre-trained model's FLOPs have been pruned.

In [4]:
num_epochs = 10

# Fine-tune the pre-trained model to generate W_s*
print("\n===Fine-tune the pre-trained model to generate W_s*===")
# Undo kernel state other cells may have left behind: the alpha-training cell sets
# requires_grad=False on every weight (loss.backward() would then fail), and the
# evaluation cell switches the model to eval mode (dropout disabled).
for param in model.parameters():
    param.requires_grad = True
model.train()

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(num_epochs):
    print("Epoch " + str(epoch) + "/" + str(num_epochs))
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()


===Fine-tune the pre-trained model to generate W_s*===
Epoch 0/10


100%|██████████| 782/782 [03:27<00:00,  3.77it/s]


Epoch 1/10


100%|██████████| 782/782 [03:30<00:00,  3.71it/s]


Epoch 2/10


100%|██████████| 782/782 [03:33<00:00,  3.67it/s]


Epoch 3/10


100%|██████████| 782/782 [03:31<00:00,  3.70it/s]


Epoch 4/10


100%|██████████| 782/782 [03:31<00:00,  3.70it/s]


Epoch 5/10


100%|██████████| 782/782 [03:33<00:00,  3.67it/s]


Epoch 6/10


100%|██████████| 782/782 [03:32<00:00,  3.68it/s]


Epoch 7/10


100%|██████████| 782/782 [03:32<00:00,  3.67it/s]


Epoch 8/10


100%|██████████| 782/782 [03:33<00:00,  3.66it/s]


Epoch 9/10


100%|██████████| 782/782 [03:32<00:00,  3.68it/s]


Save model for later use

In [5]:
# Persist the fine-tuned weights so a later session can load them instead of re-training.
torch.save(obj=model.state_dict(), f='model_parameters_10_epochs.pt')

### Create `scaling_factors` $\alpha$

1.   Train the scaling factors $\alpha$ using the target data (CIFAR10 in this case).
2.   Transform the scaling factors $\alpha$ to the filter importance score $\beta$ using the Taylor expansion method.
3.   Prune the filters based on the filter importance.
4.   Fine-tune the pruned model using the target data.



#### Load `scaling_factors` $\alpha$ (if any)

In [5]:
# Restore the scaling factors alpha (presumably the dict written by the "Save
# scaling_factors" cell: conv-layer index -> tensor).
# NOTE: pickle.load can execute arbitrary code -- only open files this notebook wrote.
with open('scaling_factor_10epoch.pkl', mode='rb') as alpha_file:
    scaling_factors = pickle.load(alpha_file)

#### Train `scaling_factors` $\alpha$
Only run this if you did not load $\alpha$ from file.

**Initialize `scaling_factors` $\alpha$**

In [None]:
num_layers = len(model.features)
scaling_factors = {}

# One learnable factor per output filter of every Conv2d layer, keyed by the layer's
# index in model.features. The (1, C_out, 1, 1) shape broadcasts over the conv output.
for layer_idx, layer in enumerate(model.features):
    if not isinstance(layer, torch.nn.Conv2d):
        continue
    print(layer, layer.out_channels)
    scaling_factors[layer_idx] = torch.rand((1, layer.out_channels, 1, 1), requires_grad=True)

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 64
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 64
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 128
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 128
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 256
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 256
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 256
Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512


**Train `scaling_factors` with the filter weights frozen**

In the paper, the authors train $\alpha$ for 60 epochs (on VGG16).

In [12]:
num_epochs = 10
# Freeze every pre-trained weight: only the scaling factors alpha are optimized here.
for param in model.parameters():
    param.requires_grad = False
criterion = torch.nn.CrossEntropyLoss()

print("\n===Train the factors alpha by optimizing the loss function===")
# Same parameters, same order as iterating scaling_factors' keys.
optimizer_alpha = torch.optim.SGD(list(scaling_factors.values()), lr=0.001, momentum=0.9)
for epoch in range(num_epochs):
    print("Epoch " + str(epoch) + "/" + str(num_epochs))
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_alpha.zero_grad()
        outputs = inputs
        for i, layer in enumerate(model.features):
            outputs = layer(outputs)
            if isinstance(layer, torch.nn.Conv2d):
                # Scale every filter's output by its alpha. `.to(device)` (instead of a
                # hard-coded `.cuda()`) keeps gradients flowing back to the alpha leaf
                # wherever it lives, and also works on a CPU-only machine.
                outputs = outputs * scaling_factors[i].to(device)
        # Mirror torchvision's VGG forward: features -> avgpool -> flatten -> classifier.
        outputs = model.avgpool(outputs)
        outputs = torch.flatten(outputs, 1)
        classification_output = model.classifier(outputs)
        # `loss` of the last batch stays defined for cells that differentiate it later.
        loss = criterion(classification_output, labels)
        loss.backward()
        optimizer_alpha.step()


===Train the factors alpha by optimizing the loss function===
Epoch 0/10


100%|██████████| 391/391 [03:52<00:00,  1.68it/s]


Epoch 1/10


100%|██████████| 391/391 [03:53<00:00,  1.67it/s]


Epoch 2/10


100%|██████████| 391/391 [03:55<00:00,  1.66it/s]


Epoch 3/10


100%|██████████| 391/391 [03:55<00:00,  1.66it/s]


Epoch 4/10


100%|██████████| 391/391 [03:56<00:00,  1.65it/s]


Epoch 5/10


100%|██████████| 391/391 [03:56<00:00,  1.66it/s]


Epoch 6/10


100%|██████████| 391/391 [03:55<00:00,  1.66it/s]


Epoch 7/10


100%|██████████| 391/391 [03:54<00:00,  1.67it/s]


Epoch 8/10


100%|██████████| 391/391 [03:54<00:00,  1.67it/s]


Epoch 9/10


100%|██████████| 391/391 [03:54<00:00,  1.67it/s]


**Save `scaling_factors`** $\alpha$

In [None]:

# Persist alpha so the expensive training above can be skipped in a later session.
with open('scaling_factor_10epoch.pkl', mode='wb') as alpha_file:
    pickle.dump(scaling_factors, alpha_file, pickle.HIGHEST_PROTOCOL)


#### Transform `scaling_factors` $\alpha$ to `importance_score` $\beta$

**Only run the cell below if you loaded the scaling factors $\alpha$ from file**. If you trained $\alpha$ above, **skip this step** (the `loss` from the last training batch is already available).

In [6]:
importance_scores = {}
num_layers = len(model.features)
criterion = torch.nn.CrossEntropyLoss()

# beta_i^j = |dL(D_t)/d(alpha_i^j) * alpha_i^j| uses the loss over the whole target
# set D_t. Previously the loop re-assigned `loss` on every batch, so all but the last
# batch's forward pass was thrown away. Instead, accumulate the sample-weighted
# gradient of each batch's mean loss w.r.t. every alpha; dividing by the sample count
# gives the gradient of the dataset-mean loss.
alpha_keys = list(scaling_factors.keys())
alpha_tensors = [scaling_factors[k] for k in alpha_keys]
grad_sums = [torch.zeros_like(alpha) for alpha in alpha_tensors]
num_samples = 0

for inputs, labels in tqdm(train_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = inputs
    for i in range(num_layers):
        outputs = model.features[i](outputs)
        if isinstance(model.features[i], torch.nn.Conv2d):
            outputs = outputs * scaling_factors[i].to(device)
    # Mirror torchvision's VGG forward: features -> avgpool -> flatten -> classifier.
    outputs = model.avgpool(outputs)
    outputs = torch.flatten(outputs, 1)
    classification_output = model.classifier(outputs)
    loss = criterion(classification_output, labels)

    # Only the alpha gradients are computed (no weight .grad is touched). retain_graph
    # keeps the current batch's graph until `loss` is re-bound, so after the loop
    # `loss` (the last batch) can still be differentiated by a later cell.
    batch_grads = torch.autograd.grad(loss, alpha_tensors, retain_graph=True)
    batch_n = labels.shape[0]
    for grad_sum, grad in zip(grad_sums, batch_grads):
        grad_sum += grad * batch_n
    num_samples += batch_n

if num_samples == 0:
    raise ValueError("train_loader yielded no samples; cannot compute importance scores")

for key, alpha, grad_sum in zip(alpha_keys, alpha_tensors, grad_sums):
    # detach: the scores are constants from here on.
    importance_scores[key] = torch.abs(grad_sum / num_samples * alpha).detach()


**Create `importance_scores` $\beta$ using Taylor expansion**

$$\beta_i^j=\left|\frac{\partial\mathcal{L}\left(D_t;W_f^s\bigodot\alpha^*\right)}{\partial\left(\alpha^*\right)_i^j}\left(\alpha^*\right)_i^j\right|$$

**Load `importance_scores` $\beta$ from file**

In [None]:
# Restore previously computed importance scores beta.
# NOTE: pickle.load can execute arbitrary code -- only open files this notebook wrote.
with open('importance_scores.pkl', mode='rb') as beta_file:
    importance_scores = pickle.load(beta_file)

**Transform `importance_scores` $\beta$ from `scaling_factors` $\alpha$**

Do not run the cell below if you have loaded $\beta$ from file.

In [14]:
# Differentiate `loss` w.r.t. all alphas in ONE backward pass (previously one pass per
# conv layer). retain_graph keeps the graph so this cell can be re-run.
alpha_grads = torch.autograd.grad(loss, list(scaling_factors.values()), retain_graph=True)
for (i, scaling_factor), first_order_derivative in zip(scaling_factors.items(), alpha_grads):
    # beta = |dL/dalpha * alpha|; detach so the stored scores carry no autograd history.
    importance_scores[i] = torch.abs(first_order_derivative * scaling_factor).detach()

**Save `importance_scores` $\beta$ to file**

In [8]:
# Persist beta so later sessions can skip recomputing it.
with open('importance_scores.pkl', mode='wb') as beta_file:
    pickle.dump(importance_scores, beta_file, pickle.HIGHEST_PROTOCOL)


## Importance-aware Fine-tuning

### Prune the model by using importance_scores $\beta$

In [71]:
# Define function to find the filter with lowest importance_score
def find_filter_to_prune(importance_scores, pruned_filters):
    """Return ``(layer_index, filter_index)`` of the unpruned filter with the lowest score.

    Args:
        importance_scores: dict mapping a layer index to a score tensor; element 0
            along dim 0 holds one score per filter (``(1, C, 1, 1)`` in this notebook).
        pruned_filters: collection of ``(layer_index, filter_index)`` pairs to skip.

    Returns:
        The pair with the smallest score. Ties go to the earliest layer (dict order),
        then to the lowest filter index. NaN scores are never selected.
        ``(None, None)`` if no eligible filter has a score below ``inf``
        (e.g. every filter is already pruned).
    """
    min_value = float('inf')
    min_layer = None
    min_filter = None

    for layer_index, scores_tensor in importance_scores.items():
        # One score per filter. The float64 copy is exact for float32 scores and
        # lets us mask entries without mutating the caller's tensor.
        scores = scores_tensor[0].detach().reshape(-1).double().clone()
        if scores.numel() == 0:
            continue

        # Never pick NaN scores or filters that are already pruned.
        skip = torch.isnan(scores)
        for p_layer, p_filter in pruned_filters:
            if p_layer == layer_index and 0 <= p_filter < scores.numel():
                skip[p_filter] = True
        scores[skip] = float('inf')

        # argmin returns the first minimal index, matching a strict-`<` scan.
        best_filter = int(torch.argmin(scores))
        best_value = scores[best_filter].item()
        # Strict `<` keeps the earlier layer when two layers tie.
        if best_value < min_value:
            min_value = best_value
            min_layer = layer_index
            min_filter = best_filter

    return min_layer, min_filter

In [103]:
# `model` is the fine-tuned VGG-16; the pruning below edits it in place.
total_params_original = sum(p.numel() for p in model.parameters())
print('Number of parameters in the original model: {:,}'.format(total_params_original))

# Keep pruned_filters across re-runs so every run of this cell prunes one more filter
# (at notebook top level, globals() is the kernel namespace).
if 'pruned_filters' not in globals():
    pruned_filters = set()

# Call the function to find the next filter to prune
layer_to_prune, filter_to_prune = find_filter_to_prune(importance_scores, pruned_filters)
if layer_to_prune is None:
    # Fail with a clear message instead of a TypeError from `features[None]` below.
    raise RuntimeError("No unpruned filter is left to prune.")

print("Next filter to prune - Layer:", layer_to_prune, "Filter:", filter_to_prune)

# NOTE: `pruned_model` is an alias of `model`, not a copy. Later cells use both names
# for the same network, so pruning here also changes `model`.
pruned_model = model
pruned_layer = pruned_model.features[layer_to_prune]

# "Prune" by zeroing the filter's weights and bias in place. Tensor shapes, and hence
# the total parameter count, are unchanged.
# NOTE(review): nothing masks these entries afterwards, so gradient updates during
# fine-tuning can make them non-zero again -- confirm whether a mask is intended.
with torch.no_grad():
    pruned_layer.weight[filter_to_prune] = 0
    pruned_layer.bias[filter_to_prune] = 0

# Record the pruned filter so it is never selected again.
pruned_filters.add((layer_to_prune, filter_to_prune))

Number of parameters in the original model: 134,301,514
Next filter to prune - Layer: 17 Filter: 148


In [104]:
def count_non_zero_parameters(model, trainable_only=False):
    """Count the scalar parameter entries of ``model`` that are not exactly zero.

    Args:
        model: a ``torch.nn.Module``.
        trainable_only: if True, count only parameters with ``requires_grad`` set.
            Defaults to False: this notebook toggles ``requires_grad`` on every
            weight (e.g. while training alpha), which made a trainable-only count
            silently drop to 0.

    Returns:
        int: number of non-zero entries.
    """
    # count_nonzero returns a scalar; p.nonzero() would materialize an index tensor
    # with one row per non-zero entry (gigabytes for VGG16's first FC layer).
    return sum(
        int(torch.count_nonzero(p))
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

# Pruning only zeroes filters in place, so the total parameter count does not change;
# the number of surviving (non-zero) entries is what shrinks.
non_zero_params_after_pruning = count_non_zero_parameters(model=pruned_model)
print("Number of non-zero parameters after pruning: {:,}".format(non_zero_params_after_pruning))


Number of non-zero parameters after pruning: 134,292,294


### Finetuning the model after pruning

In [80]:
num_epochs = 1
# Unfreeze every weight (the alpha-training cell may have frozen them).
for param in pruned_model.parameters():
    param.requires_grad = True
criterion = torch.nn.CrossEntropyLoss()

print("\n===Importance-aware fine-tuning of the pruned model===")
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=0.001, momentum=0.9)

# The importance scores beta are constants during fine-tuning: move them to the device
# once instead of copying every score tensor on every batch.
importance_scales = {i: score.to(device) for i, score in importance_scores.items()}

# Dropout must be active even if an evaluation cell ran earlier in this kernel.
pruned_model.train()
for epoch in range(num_epochs):
    print("Epoch " + str(epoch) + "/" + str(num_epochs))
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = inputs
        for i, layer in enumerate(pruned_model.features):
            outputs = layer(outputs)
            if isinstance(layer, torch.nn.Conv2d):
                # Importance-aware: scale each filter's output by its beta.
                outputs = outputs * importance_scales[i]
        # Mirror torchvision's VGG forward: features -> avgpool -> flatten -> classifier.
        outputs = pruned_model.avgpool(outputs)
        outputs = torch.flatten(outputs, 1)
        classification_output = pruned_model.classifier(outputs)
        loss = criterion(classification_output, labels)
        # No retain_graph: the graph is rebuilt every batch, so retaining it only wastes memory.
        loss.backward()
        optimizer.step()
# NOTE(review): the accuracy cell evaluates `model(images)`, i.e. WITHOUT the beta
# scaling applied here -- confirm whether beta should be folded into the weights first.


===Train the factors alpha by optimizing the loss function===
Epoch 0/1


100%|██████████| 782/782 [03:42<00:00,  3.51it/s]


### Calculate the accuracy of the model after pruning

In [82]:
model.eval()
# Per-class tallies accumulated on the device and converted to Python lists at the end,
# instead of one .item() host sync per test image.
correct_counts = torch.zeros(num_classes, device=device)
total_counts = torch.zeros(num_classes, device=device)
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        # No .squeeze(): it turned a size-1 batch into a 0-d tensor that cannot be indexed.
        hits = (predicted == labels).float()
        correct_counts += torch.bincount(labels, weights=hits, minlength=num_classes)
        total_counts += torch.bincount(labels, minlength=num_classes)

class_correct = correct_counts.tolist()
class_total = total_counts.tolist()

for i in range(num_classes):
    if class_total[i] == 0:
        # Avoid ZeroDivisionError for a class absent from the test set.
        print('Accuracy for class {}: n/a (no test samples)'.format(i))
        continue
    accuracy = 100 * class_correct[i] / class_total[i]
    print('Accuracy for class {}: {:.2f}%'.format(i, accuracy))

overall_accuracy = 100 * sum(class_correct) / max(sum(class_total), 1)
print('Overall accuracy: {:.2f}%'.format(overall_accuracy))

Accuracy for class 0: 97.30%
Accuracy for class 1: 98.00%
Accuracy for class 2: 89.30%
Accuracy for class 3: 72.90%
Accuracy for class 4: 94.20%
Accuracy for class 5: 88.60%
Accuracy for class 6: 98.60%
Accuracy for class 7: 97.00%
Accuracy for class 8: 96.90%
Accuracy for class 9: 92.30%
