In [3]:
import os
import random
import yaml
from tqdm import tqdm
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.utils.data
from dataset.fairfd import FairFD
from detectors import DETECTOR

detector_name = "spsl"
# detector_name = "ffd"
args = {"detector_path": f"./config/detector/{detector_name}.yaml"}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def init_seed(config):
    if config['manualSeed'] is None:
        config['manualSeed'] = random.randint(1, 10000)
    random.seed(config['manualSeed'])
    torch.manual_seed(config['manualSeed'])
    if config['cuda']:
        torch.cuda.manual_seed_all(config['manualSeed'])

def prepare_my_testing_data(config, root_path):
    paths = ["data/Asian", "data/Caucasian", "data/Indian", "data/African"]
    test_data_loaders = {}
    for i in range(len(paths)):
        test_set = FairFD(config, os.path.join(root_path, paths[i]))
        test_data_loader = \
            torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=config['test_batchSize'],
                shuffle=True,
                num_workers=int(config['workers']),
                collate_fn=test_set.collate_fn,
            )
        test_data_loaders[paths[i]] = test_data_loader
    return test_data_loaders

@torch.no_grad()
def inference(model, data_dict):
    predictions = model(data_dict, inference=True)
    return predictions

In [4]:
# parse options and load config
with open(args["detector_path"], 'r') as f:
    config = yaml.safe_load(f)
# print configuration
print("--------------- Configuration ---------------")
params_string = "Parameters: \n"
for key, value in config.items():
    params_string += "{}: {}".format(key, value) + "\n"
# init seed
init_seed(config)
# set cudnn benchmark if needed
if config['cudnn']:
    cudnn.benchmark = True
# prepare the testing data loaders
rootpath = "../dataset/test"
test_data_loaders = prepare_my_testing_data(config, rootpath)
# prepare the model (detector)
model_class = DETECTOR[config['model_name']]
model = model_class(config).to(device)
weights_path = f"../weights/{detector_name}.pth"
ckpt = torch.load(weights_path, map_location=device)
model.load_state_dict(ckpt, strict=True)
model = model.to(device)
model.eval()
print('===> Load checkpoint done!')

--------------- Configuration ---------------
Loading data from FairFD ...
real: 9688
Data from '../dataset/test/data/Asian' loaded.
Dataset contains 9688 images.

Loading data from FairFD ...
real: 10196
Data from '../dataset/test/data/Caucasian' loaded.
Dataset contains 10196 images.

Loading data from FairFD ...
real: 10308
Data from '../dataset/test/data/Indian' loaded.
Dataset contains 10308 images.

Loading data from FairFD ...
real: 10415
Data from '../dataset/test/data/African' loaded.
Dataset contains 10415 images.

===> Load checkpoint done!


In [6]:
import os
import torch
import pickle
import numpy as np
import yaml

filepath1 = os.path.join("./saved_activations", f"{detector_name}_Caucasian_all_outputs.pkl")
filepath2 = os.path.join("./saved_activations", f"{detector_name}_Asian_all_outputs.pkl")
filepath3 = os.path.join("./saved_activations", f"{detector_name}_African_all_outputs.pkl")
filepath4 = os.path.join("./saved_activations", f"{detector_name}_Indian_all_outputs.pkl")

if os.path.exists(filepath1) and os.path.exists(filepath2) and os.path.exists(filepath3) and os.path.exists(filepath4):
    with open(filepath1, 'rb') as f:
        Caucasian_all_outputs = pickle.load(f)
    with open(filepath2, 'rb') as f:
        Asian_all_outputs = pickle.load(f)
    with open(filepath3, 'rb') as f:
        African_all_outputs = pickle.load(f)
    with open(filepath4, 'rb') as f:
        Indian_all_outputs = pickle.load(f)
else:
    outputs = {}
    def hook_fn(module, input, output):
        class_name = module.__class__.__name__
        module_idx = len(outputs)
        m_key = f'{class_name}_{module_idx + 1}'
        outputs[m_key] = output
    
    handles = {}
    def register_hooks(model):
        for module_name, module in model.named_modules():
            if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
                handles[module_name] = module.register_forward_hook(hook_fn)
    register_hooks(model)
    
    Caucasian_all_outputs = {module_name: {'mean': None, 'count': 0} for module_name in handles.keys()}
    Asian_all_outputs = {module_name: {'mean': None, 'count': 0} for module_name in handles.keys()}
    African_all_outputs = {module_name: {'mean': None, 'count': 0} for module_name in handles.keys()}
    Indian_all_outputs = {module_name: {'mean': None, 'count': 0} for module_name in handles.keys()}
    
    for test_loader_name, test_loader in test_data_loaders.items():
        with torch.no_grad():
            for data_dict in tqdm(test_loader):
                data_dict['image'], data_dict['label'] = data_dict['image'].to(device), data_dict['label'].to(device)
                temp_output = model(data_dict)
                
                if "Caucasian" in test_loader_name:
                    for module_name, key in zip(Caucasian_all_outputs.keys(), outputs.keys()):
                        temp_data = outputs[key].cpu().detach().numpy()
                        if Caucasian_all_outputs[module_name]['mean'] is None:
                            Caucasian_all_outputs[module_name]['mean'] = np.mean(temp_data, axis=0)
                            Caucasian_all_outputs[module_name]['count'] += temp_data.shape[0]
                        else:
                            old_mean = Caucasian_all_outputs[module_name]['mean']
                            count = Caucasian_all_outputs[module_name]['count']
                            new_count = temp_data.shape[0]
                            new_mean = ((count * old_mean) + np.sum(temp_data, axis=0)) / (count + new_count)
                            Caucasian_all_outputs[module_name]['mean'] = new_mean
                            Caucasian_all_outputs[module_name]['count'] += new_count
                elif "Asian" in test_loader_name:
                    for module_name, key in zip(Asian_all_outputs.keys(), outputs.keys()):
                        temp_data = outputs[key].cpu().detach().numpy()
                        if Asian_all_outputs[module_name]['mean'] is None:
                            Asian_all_outputs[module_name]['mean'] = np.mean(temp_data, axis=0)
                            Asian_all_outputs[module_name]['count'] += temp_data.shape[0]
                        else:
                            old_mean = Asian_all_outputs[module_name]['mean']
                            count = Asian_all_outputs[module_name]['count']
                            new_count = temp_data.shape[0]
                            new_mean = ((count * old_mean) + np.sum(temp_data, axis=0)) / (count + new_count)
                            Asian_all_outputs[module_name]['mean'] = new_mean
                            Asian_all_outputs[module_name]['count'] += new_count
                elif "African" in test_loader_name:
                    for module_name, key in zip(African_all_outputs.keys(), outputs.keys()):
                        temp_data = outputs[key].cpu().detach().numpy()
                        if African_all_outputs[module_name]['mean'] is None:
                            African_all_outputs[module_name]['mean'] = np.mean(temp_data, axis=0)
                            African_all_outputs[module_name]['count'] += temp_data.shape[0]
                        else:
                            old_mean = African_all_outputs[module_name]['mean']
                            count = African_all_outputs[module_name]['count']
                            new_count = temp_data.shape[0]
                            new_mean = ((count * old_mean) + np.sum(temp_data, axis=0)) / (count + new_count)
                            African_all_outputs[module_name]['mean'] = new_mean
                            African_all_outputs[module_name]['count'] += new_count
                elif "Indian" in test_loader_name:
                    for module_name, key in zip(Indian_all_outputs.keys(), outputs.keys()):
                        temp_data = outputs[key].cpu().detach().numpy()
                        if Indian_all_outputs[module_name]['mean'] is None:
                            Indian_all_outputs[module_name]['mean'] = np.mean(temp_data, axis=0)
                            Indian_all_outputs[module_name]['count'] += temp_data.shape[0]
                        else:
                            old_mean = Indian_all_outputs[module_name]['mean']
                            count = Indian_all_outputs[module_name]['count']
                            new_count = temp_data.shape[0]
                            new_mean = ((count * old_mean) + np.sum(temp_data, axis=0)) / (count + new_count)
                            Indian_all_outputs[module_name]['mean'] = new_mean
                            Indian_all_outputs[module_name]['count'] += new_count
                else:
                    print("No such race")
                    exit(1)
                # Clear Hook Output
                outputs = {}
                
    for module_name, _ in handles.items():
        handles[module_name].remove()
        
    with open(filepath1, 'wb') as f:
        pickle.dump(Caucasian_all_outputs, f)
    with open(filepath2, 'wb') as f:
        pickle.dump(Asian_all_outputs, f)
    with open(filepath3, 'wb') as f:
        pickle.dump(African_all_outputs, f)
    with open(filepath4, 'wb') as f:
        pickle.dump(Indian_all_outputs, f)

In [7]:
# method = "weight"
# method = "activation"
method = "weight_activation"

importance_scores = {}
for key in Caucasian_all_outputs.keys():
    try:
        arrays = [Caucasian_all_outputs[key]['mean'], Asian_all_outputs[key]['mean'], African_all_outputs[key]['mean'], Indian_all_outputs[key]['mean']]
        l2_norms = [np.linalg.norm(arr, axis=(1, 2)) for arr in arrays]
        l2_norms_stacked = np.stack(l2_norms, axis=-1)
        std_l2_norms = np.std(l2_norms_stacked, axis=1)
        importance_scores[key] = std_l2_norms
    except:
        print("Error:", key)
        importance_scores[key] = 0.0000000001

detector_path = args["detector_path"]
with open(detector_path, 'r') as f:
    config = yaml.safe_load(f)
model_weights = torch.load(f"../weights/{detector_name}.pth", map_location=device)

absolute_product = {}
for key in model_weights.keys():
    try:
        if key in [onekey+".weight" for onekey in importance_scores.keys()]:
            if method == "weight":
                abs_product = torch.abs(model_weights[key]).to(device)
            elif method == "activation":
                abs_product = torch.ones_like(model_weights[key]) / torch.abs(torch.from_numpy(importance_scores[key[:-len(".weight")]])).view(-1, 1, 1, 1).to(device)
            elif method == "weight_activation":
                abs_product = torch.abs(model_weights[key]) / torch.abs(torch.from_numpy(importance_scores[key[:-len(".weight")]])).view(-1, 1, 1, 1).to(device)
            else:
                print("No this method!")
                break
            
            absolute_product[key] = abs_product
            print(f"Success:", key, model_weights[key].shape)
    except Exception as e:
        print(f"Error:{e}", key, torch.abs(model_weights[key]).shape)
importance_scores = absolute_product

Error: backbone.adjust_channel.0
Success: backbone.conv1.weight torch.Size([32, 4, 3, 3])
Success: backbone.conv2.weight torch.Size([64, 32, 3, 3])
Error:The size of tensor a (128) must match the size of tensor b (64) at non-singleton dimension 0 backbone.block1.skip.weight torch.Size([128, 64, 1, 1])
Error:The size of tensor a (64) must match the size of tensor b (128) at non-singleton dimension 0 backbone.block1.rep.0.conv1.weight torch.Size([64, 1, 3, 3])
Success: backbone.block1.rep.0.pointwise.weight torch.Size([128, 64, 1, 1])
Success: backbone.block1.rep.3.conv1.weight torch.Size([128, 1, 3, 3])
Success: backbone.block1.rep.3.pointwise.weight torch.Size([128, 128, 1, 1])
Error:The size of tensor a (256) must match the size of tensor b (128) at non-singleton dimension 0 backbone.block2.skip.weight torch.Size([256, 128, 1, 1])
Error:The size of tensor a (128) must match the size of tensor b (256) at non-singleton dimension 0 backbone.block2.rep.1.conv1.weight torch.Size([128, 1, 3

In [14]:
from IPython.display import clear_output

pruning_fraction = 0.001
model_class = DETECTOR[config['model_name']]
model = model_class(config)
model_weights = torch.load(f"../weights/{detector_name}.pth", map_location=device)
model.load_state_dict(model_weights)

pruned_state_dict = model.state_dict()
for name, module in model.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Linear, nn.Conv1d)):
        try:
            weight = module.weight.data
            score = importance_scores[name + '.weight']
            assert weight.shape == score.shape, f"Shape mismatch for {name}: weight {weight.shape}, score {score.shape}"
            num_params = weight.numel()
            num_to_prune = int(pruning_fraction * num_params)
            _, indices = torch.topk(score.view(-1), num_to_prune, largest=False)
            mask = torch.ones_like(weight).view(-1)
            mask[indices] = 0
            mask = mask.view(weight.shape)
            pruned_weight = weight * mask
            module.weight.data.copy_(pruned_weight)
            pruned_state_dict[name + '.weight'] = pruned_weight
            print(f"Pruned {name}: {num_to_prune} out of {num_params} parameters")
        except Exception as e:
            print(f"No Pruned {name}", e)
torch.save(pruned_state_dict, f"../weights/{detector_name}_prun.pth")

# Run test code
# detector_path = f"./config/detector/{detector_name}.yaml"
# weights_path = f"../weights/{detector_name}_prun.pth"
# !CUDA_VISIBLE_DEVICES=0 python test-get-confidence.py --detector_path={detector_path} --weights_path={weights_path}

# # Clear Jupyter Output
# clear_output()