In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset, random_split
from torch import nn
import numpy as np
import sys
import matplotlib.pyplot as plt
import math
import time
import random
import sklearn.metrics as perf
import os
import cv2
import torch_pruning as tp
import copy, time
import torch.nn.utils.prune as prune

from models.models import MTLClassifier, AgeRegressor, GenderClassifier, EthnicityClassifier
from utils.data import FacesDataset, data_transform
from utils.training import train_mtl_model, train_age_model, train_gender_model, train_ethnicity_model
from utils.evaluation import run_evaluation, show_example_predictions
# When True, the cells below that check this flag move the model (and input batches) to CUDA.
use_gpu: bool = False

In [2]:
### Load in the data
# Wrap the images in the 'UTKFace' folder in the project's FacesDataset with the standard transform.
folder, transform = 'UTKFace', data_transform()
dataset = FacesDataset(transform=transform, folder=folder)

In [3]:
### Set up train and val datasets and loaders
# 80/20 split; the seeded generator makes the split identical on every run.
train_len = int(len(dataset) * 0.8)
val_len = len(dataset) - train_len
split_generator = torch.Generator().manual_seed(8)
train_dataset, val_dataset = random_split(dataset, [train_len, val_len], generator=split_generator)

# Both loaders share the batch size; only training batches are shuffled.
loader_kwargs = dict(batch_size=16)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, **loader_kwargs)

Pruning conv layers: helpers to prune every Conv2d, fine-tune the pruned model, and time its inference

In [None]:
def prune_model(model, PRUNING_PERCENT=0.2, input_shape=(1, 3, 224, 224)):
    """Structurally prune the output channels of every Conv2d in ``model``, in place.

    For each ``torch.nn.Conv2d``, ``tp.strategy.L1Strategy`` picks the channels to
    remove with ``amount=PRUNING_PERCENT`` (presumably the filters with the smallest
    L1 norm; see the torch_pruning docs). The removal is then executed through a
    torch_pruning dependency graph, which presumably propagates it to the layers
    that consume those channels.

    Args:
        model: network to prune; modified in place.
        PRUNING_PERCENT: fraction of output channels to remove from each conv
            (passed to the strategy as ``amount``).
        input_shape: shape of the random dummy input used to trace ``model`` when
            building the dependency graph. Defaults to one 3x224x224 image.
            NOTE(review): confirm this matches the dataset transform.

    Returns:
        The same ``model`` object, after pruning.
    """
    # Trace with a dummy input on the model's own device; a CPU input fed to a
    # CUDA model would fail during tracing.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=torch.randn(*input_shape, device=device))
    strategy = tp.strategy.L1Strategy()
    for module in model.modules():
        # Only conv layers are pruned; Linear layers are intentionally left intact.
        if isinstance(module, torch.nn.Conv2d):
            pruning_idxs = strategy(module.weight, amount=PRUNING_PERCENT)
            # Small ratios on narrow layers may select no channels at all; there is
            # nothing to prune then, so skip building/executing an empty plan.
            if len(pruning_idxs) == 0:
                continue
            pruning_plan = DG.get_pruning_plan(module, tp.prune_conv, idxs=pruning_idxs)
            pruning_plan.exec()
    return model

In [None]:
#FINE TUNING:
def fine_tune(model, num_epochs=5):
    """Fine-tune ``model`` in place (typically right after pruning).

    Trains on the global ``train_loader`` / ``val_loader`` through
    ``train_mtl_model``. It uses Adam with default hyper-parameters, MSE loss for
    age, cross-entropy for gender and ethnicity, and fixed per-task loss
    coefficients. ``save=False`` is passed (presumably so no checkpoint is written
    from here).

    The model is moved to CUDA when it is available and to the CPU otherwise.
    NOTE(review): this assumes ``train_mtl_model`` moves each batch to the
    model's device; confirm in utils.training.

    Args:
        model: multi-task model to fine-tune; modified in place.
        num_epochs: number of training epochs (default 5).

    Returns:
        The fine-tuned model (the same object, now on the chosen device).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # nn.Module.to moves parameters in place, so the caller's reference follows.
    model = model.to(device)
    # Loss weights per task; the tiny age weight presumably offsets the much larger
    # scale of the MSE regression loss relative to the classification losses.
    age_coeff = 0.004
    gender_coeff = 2
    ethni_coeff = 1
    age_criterion = nn.MSELoss()
    gender_criterion = nn.CrossEntropyLoss()
    ethni_criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters())

    train_mtl_model(num_epochs=num_epochs, model=model, optimizer=optimizer,
                    train_loader=train_loader, val_loader=val_loader,
                    age_criterion=age_criterion, gender_criterion=gender_criterion, ethni_criterion=ethni_criterion,
                    age_coeff=age_coeff, gender_coeff=gender_coeff, ethni_coeff=ethni_coeff, save=False)
    return model

In [3]:
def get_inference_time(model, loader=None, num_batches=1):
    """Measure the per-batch prediction latency of the multi-task model.

    The timed region covers the forward pass plus converting the three heads'
    outputs into predictions (squeeze for age, argmax for gender and ethnicity).
    Data loading and the host-to-device copy are excluded.

    Args:
        model: model returning ``(age_output, gender_output, ethnicity_output)``.
            It is moved to CUDA when the global ``use_gpu`` is True, and is always
            put in eval mode.
        loader: iterable of ``(img, age, gender, ethnicity)`` batches; defaults
            to the global ``val_loader``.
        num_batches: number of leading batches to time; their mean latency is
            returned. The default of 1 times only the first batch, which may
            include one-off warm-up cost.

    Returns:
        Mean wall-clock seconds per batch (also printed).

    Raises:
        ValueError: if ``num_batches < 1`` or ``loader`` yields no batches.
    """
    if num_batches < 1:
        raise ValueError("num_batches must be at least 1")
    if loader is None:
        loader = val_loader
    if use_gpu:
        model = model.cuda()
    model.eval()

    latencies = []
    with torch.no_grad():
        # Labels are not needed for timing, so they are neither converted nor moved.
        for img, _age, _gender, _ethnicity in loader:
            if use_gpu:
                img = img.cuda()
                # Make sure the host->device copy has finished before the clock starts.
                torch.cuda.synchronize()
            start = time.perf_counter()
            age_output, gender_output, ethnicity_output = model(img)
            # Predictions are computed (though unused) so that post-processing is
            # part of the measured latency, as in a real inference call.
            age_pred = age_output.squeeze(1)
            gender_pred = torch.argmax(gender_output, dim=1)
            ethnicity_pred = torch.argmax(ethnicity_output, dim=1)
            if use_gpu:
                # CUDA kernels run asynchronously; wait for them so the timing
                # reflects the actual compute, not just kernel launches.
                torch.cuda.synchronize()
            latencies.append(time.perf_counter() - start)
            if len(latencies) >= num_batches:
                break

    if not latencies:
        raise ValueError("loader yielded no batches; cannot measure inference time")
    inference_latency = sum(latencies) / len(latencies)
    print("Time to predict:", inference_latency)
    return inference_latency

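Basic pruning loop: load the trained multi-task model, then on each iteration prune a fresh copy at an increasing ratio, fine-tune it, and log its evaluation scores and inference time.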

In [None]:
PATH = 'models/mtl_face_model_v1.pt'
ITER_PRUNING = 15  # number of prune -> fine-tune -> evaluate iterations
PRUNING_PERCENT = 0.01  # per-conv pruning ratio used on the first iteration

model = MTLClassifier()
# Load onto the CPU first so a checkpoint saved from a GPU session also loads on a
# CPU-only machine; the model is moved to CUDA below only when use_gpu is set.
model.load_state_dict(torch.load(PATH, map_location='cpu'))
if use_gpu:
    model = model.cuda()
model.eval()

In [None]:
tasks = ['age', 'gender', 'ethnicity']
mtl_model = True
PRUNING_STEP = 0.02  # increase of the pruning ratio from one iteration to the next
os.makedirs("pruned_models", exist_ok=True)
for idx_prune in range(ITER_PRUNING):
    # Derive the ratio from the iteration index instead of mutating PRUNING_PERCENT,
    # so re-running this cell restarts the sweep from the configured starting ratio.
    pruning_percent = PRUNING_PERCENT + PRUNING_STEP * idx_prune
    print(f"\n\nIteration {idx_prune+1} - Pruning {pruning_percent*100:g}% of the least important neurons in every conv")
    # Every iteration prunes a fresh copy of the base model, so the ratios form an
    # independent sweep rather than cumulative pruning.
    pruned_model = prune_model(copy.deepcopy(model), pruning_percent)
    if use_gpu:
        pruned_model = pruned_model.cuda()
    fine_tune(pruned_model)
    if not use_gpu:
        pruned_model = pruned_model.to(torch.device("cpu"))
    pruned_model.eval()
    inf_time = get_inference_time(pruned_model)
    score_dict = run_evaluation(pruned_model, val_loader, tasks, mtl_model)

    with open("pruned_models/model_data.txt", "a") as log_file:
        log_file.write(f'{idx_prune+1},{score_dict["age"][1]},{score_dict["gender"][1]},{score_dict["ethnicity"][1]},{inf_time}\n')
    # Save the pruned network (not the unpruned base `model`). The whole module is
    # saved because its pruned layer shapes no longer match a fresh MTLClassifier(),
    # so a state_dict alone could not be loaded back into one.
    torch.save(pruned_model, f"pruned_models/mtl_face_model_{idx_prune+1}.pt")
