In [15]:
!pip install torch_pruning
import torch
from torchvision.models import resnet18
from torchvision import datasets, models, transforms
import torch_pruning as tp
import time
import random



In [16]:
# Set Variables
model = resnet18(pretrained=True)
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, 9)
example_inputs = torch.randn(1, 3, 224, 224)
loss_function = torch.nn.CrossEntropyLoss()
device = "cuda" if torch.cuda.is_available() else "cpu"

In [17]:
print(device)

cpu


In [18]:
# List of values
ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]

# Get a random value from the list
pruning_ratio = random.choice(ratios)
pruning_percentage = float(pruning_ratio * 100)

print(f"Your pruning percentage for this test is {pruning_percentage} %.")

Your pruning percentage for this test is 90.0 %.


In [19]:
# 1. Importance criterion
imp = tp.importance.GroupTaylorImportance() # or GroupNormImportance(p=2), GroupHessianImportance(), etc.

In [20]:
# 2. Ignore Output Layer for Pruning
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 9:
        ignored_layers.append(m) # DO NOT prune the final classifier!

In [21]:
# 3. Create Meta Pruner Model
pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=pruning_ratio, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

In [22]:
# 4. Print the base number of parameters and MACs
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"Base Params in Millions: {base_nparams/1e6}\n")
print(f"Base MACs in Millions: {base_macs/1e6}\n")

Base Params in Millions: 11.181129

Base MACs in Millions: 1821.669385



In [23]:
# 5. Prune
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
if isinstance(imp, tp.importance.GroupTaylorImportance):
    outputs = model(example_inputs)
    target = torch.randint(0, 9, (1,))  # 9 classes for classification
    loss = loss_function(outputs, target)
    loss.backward() # before pruner.step()

In [24]:
# 6. Print the pruned model number of parameters and MACs
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"Pruned Params in Millions: {nparams/1e6}\n")
print(f"Pruned MACs in Millions: {macs/1e6}\n")

Pruned Params in Millions: 0.111046

Pruned MACs in Millions: 27.467222



In [25]:
# 7. Set up training and testing directories
train_dir = "train"
val_dir = "val"

In [26]:
# 8. Perform Data Augmentation on train and test data

transforms_train = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.RandomCrop((224, 224)),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(), 
    transforms.ColorJitter(brightness=0.1),  
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
])

transforms_val = transforms.Compose([
    transforms.Resize((224, 224)),  
     transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [27]:
# 9. Prepare Dataloaders
train_dataset = datasets.ImageFolder(train_dir, transforms_train)
val_dataset = datasets.ImageFolder(val_dir, transforms_val)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=12, shuffle=True, num_workers=0)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=12, shuffle=False, num_workers=0)

In [28]:
# Train Model

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)

train_loss=[]
train_accuary=[]
val_loss=[]
val_accuary=[]

num_epochs = 15   #(set no of epochs)
start_time = time.time() #(for showing time)

for epoch in range(num_epochs): #(loop for every epoch)
    print("Epoch {} running".format(epoch)) #(printing message)
    """ Training Phase """
    model.train()    #(training model)
    running_loss = 0.   #(set loss 0)
    running_corrects = 0 
    # load a batch data of images
    for i, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device) 
        # forward inputs and get output
        optimizer.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = loss_function(outputs, labels)
        # get loss value and update the network weights
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_corrects += torch.sum(preds == labels.data).item()
    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_corrects / len(train_dataset) * 100.
    # Append result
    train_loss.append(epoch_loss)
    train_accuary.append(epoch_acc)
    # Print progress
    print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s\n'.format(epoch+1, epoch_loss, epoch_acc, time.time() -start_time))
    
# Evaluate Model on Test Data

model.eval()
with torch.no_grad():
    running_loss = 0.
    running_corrects = 0
    for inputs, labels in val_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = loss_function(outputs, labels)
        running_loss += loss.item()
        running_corrects += torch.sum(preds == labels.data).item()
    epoch_loss = running_loss / len(val_dataset)
    epoch_acc = running_corrects / len(val_dataset) * 100.
    # Append result
    val_loss.append(epoch_loss)
    val_accuary.append(epoch_acc)
    # Print progress
    print('Test Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s\n'.format(epoch_loss, epoch_acc, time.time()- start_time))

Epoch 0 running
[Train #1] Loss: 0.1821 Acc: 18.7948% Time: 34.5810s

Epoch 1 running
[Train #2] Loss: 0.1707 Acc: 30.4161% Time: 69.0289s

Epoch 2 running
[Train #3] Loss: 0.1625 Acc: 33.7159% Time: 105.2924s

Epoch 3 running
[Train #4] Loss: 0.1535 Acc: 39.7418% Time: 138.5668s

Epoch 4 running
[Train #5] Loss: 0.1487 Acc: 40.3156% Time: 175.0220s

Epoch 5 running
[Train #6] Loss: 0.1401 Acc: 47.6327% Time: 211.7935s

Epoch 6 running
[Train #7] Loss: 0.1381 Acc: 45.9110% Time: 244.4295s

Epoch 7 running
[Train #8] Loss: 0.1313 Acc: 50.2152% Time: 281.2575s

Epoch 8 running
[Train #9] Loss: 0.1260 Acc: 51.2195% Time: 314.7184s

Epoch 9 running
[Train #10] Loss: 0.1219 Acc: 52.9412% Time: 353.8662s

Epoch 10 running
[Train #11] Loss: 0.1225 Acc: 51.2195% Time: 388.7078s

Epoch 11 running
[Train #12] Loss: 0.1156 Acc: 55.0933% Time: 426.9842s

Epoch 12 running
[Train #13] Loss: 0.1136 Acc: 56.0976% Time: 463.5556s

Epoch 13 running
[Train #14] Loss: 0.1106 Acc: 56.3845% Time: 496.3715s
