# TRANSTAILOR METHOD

## Import

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import itertools
import pickle


## Load model and dataset
* Model: VGG16
* Dataset: CIFAR10

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE: " + str(device))

# Load the ImageNet-pretrained VGG16 model
model = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
batch_size = 64

# Resize CIFAR10 images to 224x224 and normalize with the ImageNet channel
# statistics that match the pretrained weights loaded above.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the CIFAR10 training split (downloaded into ./data on first run)
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Worker processes and pinned memory only pay off when batches go to a GPU.
# Compare device.type: a torch.device object never equals the plain string 'cuda'.
kwargs = {'num_workers': 10, 'pin_memory': True} if device.type == 'cuda' else {}
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, **kwargs)

# Replace the last layer of the model with a new layer that matches the number of classes in CIFAR10
num_classes = 10
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes)

model = model.to(device)

DEVICE: cuda


Files already downloaded and verified


In [3]:
num_epochs = 10
# Report overall and trainable parameter counts of the current model.
param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
print("total", sum(count for count, _ in param_stats))
print("trainable", sum(count for count, trainable in param_stats if trainable))

total 134301514
trainable 134301514


## Fine-tune the model on the target data (CIFAR10)

Load the fine-tuned weights saved by a previous training run

In [4]:
# Load the fine-tuned parameters from disk. map_location places the tensors on
# `device`, so a checkpoint saved on a GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load('model_parameters_10_epochs.pt', map_location=device))

<All keys matched successfully>

If no checkpoint is found, fine-tune for 10 epochs

In [3]:
num_epochs = 10

# Fine-tune the pre-trained model to generate W_s*
print("\n===Fine-tune the pre-trained model to generate W_s*===")
# Make the cell safe to (re-)run in any order: the alpha-training cell sets
# requires_grad=False on every weight, which would make loss.backward() fail here.
for param in model.parameters():
    param.requires_grad = True
model.train()  # train-time behaviour for layers such as dropout
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(num_epochs):
    print("Epoch " + str(epoch) + "/" + str(num_epochs))
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()


===Fine-tune the pre-trained model to generate W_s*===
Epoch 0/10


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10


Save model for later use

In [4]:
# Persist the fine-tuned weights so later sessions can reload them instead of retraining.
torch.save(obj=model.state_dict(), f='model_parameters_10_epochs.pt')

## Define and train the scaling factors `α`

1.   Train the scaling factors using the target data (CIFAR10 in this case).
2.   Transform the scaling factors to the filter importance using the Taylor expansion method.
3.   Prune the filters based on the filter importance.
4.   Fine-tune the pruned model using the target data.



### Init `scaling_factors`

**Load `scaling_factors` (if any)**

In [4]:
# Restore the dict of per-layer alpha tensors from disk.
# NOTE: pickle.load can execute arbitrary code — only load files you produced yourself.
with open(file='scaling_factor_10epoch.pkl', mode='rb') as pkl_file:
    scaling_factors = pickle.load(pkl_file)

**Init `scaling_factors`**

In [None]:
num_layers = len(model.features)
scaling_factors = {}

# One learnable alpha per output filter of every Conv2d, keyed by the layer's
# index in model.features and shaped (1, C, 1, 1) to broadcast over the feature map.
for layer_idx, module in enumerate(model.features):
    if not isinstance(module, torch.nn.Conv2d):
        continue
    print(module, module.out_channels)
    scaling_factors[layer_idx] = torch.rand((1, module.out_channels, 1, 1), requires_grad=True)

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 64
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 64
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 128
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 128
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 256
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 256
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 256
Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 512


### Train `scaling_factors` with the model weights frozen

In [12]:
# Fit the per-filter scaling factors alpha with every model weight frozen:
# the only tensors the optimizer updates are the ones in `scaling_factors`.
num_epochs = 10
for param in model.parameters():
    param.requires_grad = False
criterion = torch.nn.CrossEntropyLoss()

print("\n===Train the factors alpha by optimizing the loss function===")
# Keys of `scaling_factors` are indices of the Conv2d layers in model.features.
optimizer_alpha = torch.optim.SGD(list(scaling_factors.values()), lr=0.001, momentum=0.9)
# NOTE(review): train/eval mode is not set here, so layers such as dropout follow
# the model's current mode (train unless changed elsewhere) — confirm intended.
for epoch in range(num_epochs):
    print("Epoch " + str(epoch) + "/" + str(num_epochs))
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_alpha.zero_grad()
        outputs = inputs
        for i, layer in enumerate(model.features):
            outputs = layer(outputs)
            if isinstance(layer, torch.nn.Conv2d):
                # Scale each filter's output map by its alpha. .to(device) works on CPU
                # and CUDA alike and still propagates gradients to the leaf tensor.
                outputs = outputs * scaling_factors[i].to(device)
        # Same head as torchvision's VGG.forward: avgpool -> flatten -> classifier.
        outputs = model.avgpool(outputs)
        outputs = torch.flatten(outputs, 1)
        classification_output = model.classifier(outputs)
        loss = criterion(classification_output, labels)
        loss.backward()
        optimizer_alpha.step()


===Train the factors alpha by optimizing the loss function===
Epoch 0/10


100%|██████████| 391/391 [03:52<00:00,  1.68it/s]


Epoch 1/10


100%|██████████| 391/391 [03:53<00:00,  1.67it/s]


Epoch 2/10


100%|██████████| 391/391 [03:55<00:00,  1.66it/s]


Epoch 3/10


100%|██████████| 391/391 [03:55<00:00,  1.66it/s]


Epoch 4/10


100%|██████████| 391/391 [03:56<00:00,  1.65it/s]


Epoch 5/10


100%|██████████| 391/391 [03:56<00:00,  1.66it/s]


Epoch 6/10


100%|██████████| 391/391 [03:55<00:00,  1.66it/s]


Epoch 7/10


100%|██████████| 391/391 [03:54<00:00,  1.67it/s]


Epoch 8/10


100%|██████████| 391/391 [03:54<00:00,  1.67it/s]


Epoch 9/10


100%|██████████| 391/391 [03:54<00:00,  1.67it/s]


**Save `scaling_factors`**

In [None]:

# Persist the alpha dict (keys: indices into model.features) so the expensive
# training cell can be skipped in later sessions.
with open(file='scaling_factor_10epoch.pkl', mode='wb') as pkl_file:
    pickle.dump(scaling_factors, pkl_file, protocol=pickle.HIGHEST_PROTOCOL)


### Transform `scaling_factors` to `importance_score`

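**Always run the cell below before computing importance scores** (not only when `scaling_factors` was loaded from file). It recomputes `loss` with a fresh autograd graph, which the Taylor-expansion cell differentiates. The graph of the training loop's last loss has already been freed by `loss.backward()`.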

In [14]:
# Build `loss` for the Taylor-expansion cell. dL/d(alpha) is accumulated over
# every batch of the training set (not just one batch), so the importance
# scores reflect the whole target dataset D_t.
importance_scores = {}
num_layers = len(model.features)
criterion = torch.nn.CrossEntropyLoss()

# Only alpha needs gradients; freezing the weights keeps each batch's graph small
# even when the alpha-training cell was skipped (alpha loaded from file).
for param in model.parameters():
    param.requires_grad = False

alpha_tensors = list(scaling_factors.values())
grad_sums = [torch.zeros_like(alpha) for alpha in alpha_tensors]
num_samples = 0

for inputs, labels in tqdm(train_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = inputs
    for i, layer in enumerate(model.features):
        outputs = layer(outputs)
        if isinstance(layer, torch.nn.Conv2d):
            # .to(device) works on CPU and CUDA and keeps the leaf alpha in the graph.
            outputs = outputs * scaling_factors[i].to(device)
    # Same head as torchvision's VGG.forward: avgpool -> flatten -> classifier.
    outputs = model.avgpool(outputs)
    outputs = torch.flatten(outputs, 1)
    classification_output = model.classifier(outputs)
    batch_loss = criterion(classification_output, labels)
    if alpha_tensors:
        batch_grads = torch.autograd.grad(batch_loss, alpha_tensors)
        # CrossEntropyLoss averages within a batch; weighting by batch size and
        # dividing by num_samples below gives the gradient of the dataset-mean loss.
        for grad_sum, grad in zip(grad_sums, batch_grads):
            grad_sum += grad * inputs.shape[0]
    num_samples += inputs.shape[0]

if num_samples == 0:
    raise ValueError("train_loader yielded no samples; cannot estimate importance scores")

# Linear surrogate: its gradient w.r.t. each alpha equals the dataset-mean
# gradient accumulated above, so torch.autograd.grad(loss, alpha) returns
# dL(D_t)/d(alpha). Its value is NOT the cross-entropy loss — do not log it as such.
loss = sum(
    (alpha * (grad_sum / num_samples)).sum()
    for alpha, grad_sum in zip(alpha_tensors, grad_sums)
)


**Create `importance_scores` using Taylor expansion**

$$\beta_i^j=\left|\frac{\partial\mathcal{L}\left(D_t;W_f^s\bigodot\alpha^*\right)}{\partial\left(\alpha^*\right)_i^j}\left(\alpha^*\right)_i^j\right|$$

In [19]:
# Taylor-expansion importance: beta = |dL/d(alpha) * alpha| for every filter.
# One autograd.grad call yields all gradients in a single backward pass;
# retain_graph=True keeps `loss` differentiable so this cell can be re-run.
alpha_indices = list(scaling_factors.keys())
if alpha_indices:
    alpha_grads = torch.autograd.grad(
        loss, [scaling_factors[i] for i in alpha_indices], retain_graph=True
    )
    for i, first_order_derivative in zip(alpha_indices, alpha_grads):
        # detach(): scores are plain values for ranking filters, not part of any graph.
        importance_scores[i] = torch.abs(first_order_derivative * scaling_factors[i].detach())