# Fine pruning defense

This notebook will contain code to run the fine prining defense https://link.springer.com/chapter/10.1007/978-3-030-00470-5_13

In [1]:
import os

In [2]:
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [3]:
import tqdm
import numpy as np
import pandas as pd

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications.vgg16 import VGG16

from kerassurgeon import Surgeon
from kerassurgeon import identify
from kerassurgeon.operations import delete_channels

from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split

In [4]:
repeats = 5
ft_split = 0.01
ft_epochs = 2
ft_lr = 0.0001
prune_stop = 0.04

tgt_layer = 'block5_conv3'
exp_dir = '../../64/'

## Repeated evaluation

In [5]:
results = []

In [6]:
for r in tqdm(range(repeats)):
    
    # Load data
    trn_x = np.load(os.path.join(exp_dir, 'trn_x.npy'))
    trn_y = np.load(os.path.join(exp_dir, 'trn_y.npy'))
    pt_x = np.load(os.path.join(exp_dir, 'pt_x.npy'))
    pt_y = np.load(os.path.join(exp_dir, 'pt_y.npy'))
    p_x = np.load(os.path.join(exp_dir, 'p_x.npy'))
    p_y = np.load(os.path.join(exp_dir, 'p_y.npy'))
    tst_x = np.load(os.path.join(exp_dir, 'tst_x.npy'))
    tst_y = np.load(os.path.join(exp_dir, 'tst_y.npy')) 
    tst_x, ft_x, tst_y, ft_y = train_test_split(tst_x, tst_y, test_size=ft_split)
    print('Shapes: \ntrn: {} - {}\npt: {} - {}\np: {} - {}\ntst: {} - {}\nft: {} - {}'.format(
        trn_x.shape, trn_y.shape, pt_x.shape, pt_y.shape, p_x.shape, p_y.shape, tst_x.shape, 
        tst_y.shape, ft_x.shape, ft_y.shape))
    pt_in_ft = [i in pt_x for i in ft_x]
    
    # Load model
    model = keras.models.load_model(os.path.join(exp_dir, 'model'))
    
    # Baselines
    orig_tst_score = model.evaluate(tst_x, tst_y, verbose=0)
    orig_pt_score = model.evaluate(pt_x, pt_y, verbose=0)
    orig_num_params = model.count_params()
    print('\nOrig tst acc: {} - orig tgt acc: {} - orig num params: {}\n'.format(
        orig_tst_score[1], orig_pt_score[1], orig_num_params))
    
    # Pruning
    stop_cond = False
    cur_idx = 0

    # Find the neurons with high frequency of zero activations
    target_layer = model.get_layer(tgt_layer)
    apoz = identify.get_apoz(model, target_layer, tst_x)
    print('Neurons in layer', len(apoz))  
    high_apoz_channels = identify.high_apoz(apoz, "both", cutoff_absolute=0.6)
    assert len(high_apoz_channels) != 0, 'Zero channels found!'

    # Prune neurons until condition is met
    while not stop_cond:
        target_layer = model.get_layer(tgt_layer)  

        model = delete_channels(model, target_layer, [high_apoz_channels[cur_idx],])
        model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
        score = model.evaluate(tst_x, tst_y, verbose=0)

        cur_idx += 1
        if orig_tst_score[1] - score[1] > prune_stop:
            stop_cond = True
        
    # Pruned evaluations
    pruned_tst_score = model.evaluate(tst_x, tst_y, verbose=0)
    pruned_pt_score = model.evaluate(pt_x, pt_y, verbose=0)
    pruned_num_params = model.count_params()
    print('\nPruned neurons: {} - pruned tst acc: {} - pruned tgt acc: {} - pruned num params: {}\n'.format(
        cur_idx, pruned_tst_score[1], pruned_pt_score[1], pruned_num_params))
    
    # Fine tuning
    model.optimizer.learning_rate.assign(ft_lr)
    model.fit(ft_x, ft_y, batch_size=32, epochs=ft_epochs)
    
    # Fine tuned evaluations
    ft_tst_score = model.evaluate(tst_x, tst_y, verbose=0)
    ft_pt_score = model.evaluate(pt_x, pt_y, verbose=0)
    ft_num_params = model.count_params()
    print('\Ft tst acc: {} - ft tgt acc: {} - ft num params: {}\n'.format(
        ft_tst_score[1], ft_pt_score[1], ft_num_params))
    
    # Accumulate results
    results.append({
        'repetition': r,
        'orig_tst_acc': orig_tst_score[1],
        'orig_tgt_acc': orig_pt_score[1],
        'orig_params': orig_num_params,
        'num_pruned_neurons': cur_idx,
        'pruned_tst_acc': pruned_tst_score[1],
        'pruned_tgt_acc': pruned_pt_score[1],
        'pruned_params': pruned_num_params,
        'ft_tst_acc': ft_tst_score[1],
        'ft_tgt_acc': ft_pt_score[1],
        'ft_params': ft_num_params,
        'pt_in_ft': sum(pt_in_ft)
    })
    
    # Cleanup
    del trn_x, trn_y, pt_x, pt_y, p_x, p_y, tst_x, tst_y, ft_x, ft_y, model
    keras.backend.clear_session()
    print('-'*80, '\n')

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

Shapes: 
trn: (7144, 100, 100, 3) - (7144, 2)
pt: (37, 100, 100, 3) - (37, 2)
p: (144, 100, 100, 3) - (144, 2)
tst: (5993, 100, 100, 3) - (5993, 2)
ft: (61, 100, 100, 3) - (61, 2)

Orig tst acc: 0.8384782075881958 - orig tgt acc: 0.7567567825317383 - orig num params: 14715714

Neurons in layer 512
Deleting 1/512 channels from layer: block5_conv3
Deleting 1/511 channels from layer: block5_conv3
Deleting 1/510 channels from layer: block5_conv3
Deleting 1/509 channels from layer: block5_conv3
Deleting 1/508 channels from layer: block5_conv3
Deleting 1/507 channels from layer: block5_conv3
Deleting 1/506 channels from layer: block5_conv3
Deleting 1/505 channels from layer: block5_conv3
Deleting 1/504 channels from layer: block5_conv3
Deleting 1/503 channels from layer: block5_conv3
Deleting 1/502 channels from layer: block5_conv3
Deleting 1/501 channels from layer: block5_conv3
Deleting 1/500 channels from layer: block5_conv3
Deleting 1/499 channels from layer: block5_conv3
Deleting 1/498 

Deleting 1/486 channels from layer: block5_conv3
Deleting 1/485 channels from layer: block5_conv3
Deleting 1/484 channels from layer: block5_conv3
Deleting 1/483 channels from layer: block5_conv3
Deleting 1/482 channels from layer: block5_conv3
Deleting 1/481 channels from layer: block5_conv3
Deleting 1/480 channels from layer: block5_conv3
Deleting 1/479 channels from layer: block5_conv3
Deleting 1/478 channels from layer: block5_conv3
Deleting 1/477 channels from layer: block5_conv3
Deleting 1/476 channels from layer: block5_conv3
Deleting 1/475 channels from layer: block5_conv3
Deleting 1/474 channels from layer: block5_conv3
Deleting 1/473 channels from layer: block5_conv3
Deleting 1/472 channels from layer: block5_conv3
Deleting 1/471 channels from layer: block5_conv3
Deleting 1/470 channels from layer: block5_conv3
Deleting 1/469 channels from layer: block5_conv3
Deleting 1/468 channels from layer: block5_conv3
Deleting 1/467 channels from layer: block5_conv3
Deleting 1/466 chann

Deleting 1/448 channels from layer: block5_conv3
Deleting 1/447 channels from layer: block5_conv3
Deleting 1/446 channels from layer: block5_conv3
Deleting 1/445 channels from layer: block5_conv3
Deleting 1/444 channels from layer: block5_conv3
Deleting 1/443 channels from layer: block5_conv3
Deleting 1/442 channels from layer: block5_conv3
Deleting 1/441 channels from layer: block5_conv3
Deleting 1/440 channels from layer: block5_conv3
Deleting 1/439 channels from layer: block5_conv3
Deleting 1/438 channels from layer: block5_conv3
Deleting 1/437 channels from layer: block5_conv3
Deleting 1/436 channels from layer: block5_conv3
Deleting 1/435 channels from layer: block5_conv3
Deleting 1/434 channels from layer: block5_conv3
Deleting 1/433 channels from layer: block5_conv3
Deleting 1/432 channels from layer: block5_conv3
Deleting 1/431 channels from layer: block5_conv3
Deleting 1/430 channels from layer: block5_conv3
Deleting 1/429 channels from layer: block5_conv3
Deleting 1/428 chann



Epoch 2/2
\Ft tst acc: 0.7088269591331482 - ft tgt acc: 0.8648648858070374 - ft num params: 14190060

-------------------------------------------------------------------------------- 

Shapes: 
trn: (7144, 100, 100, 3) - (7144, 2)
pt: (37, 100, 100, 3) - (37, 2)
p: (144, 100, 100, 3) - (144, 2)
tst: (5993, 100, 100, 3) - (5993, 2)
ft: (61, 100, 100, 3) - (61, 2)

Orig tst acc: 0.8396462798118591 - orig tgt acc: 0.7567567825317383 - orig num params: 14715714

Neurons in layer 512
Deleting 1/512 channels from layer: block5_conv3
Deleting 1/511 channels from layer: block5_conv3
Deleting 1/510 channels from layer: block5_conv3
Deleting 1/509 channels from layer: block5_conv3
Deleting 1/508 channels from layer: block5_conv3
Deleting 1/507 channels from layer: block5_conv3
Deleting 1/506 channels from layer: block5_conv3
Deleting 1/505 channels from layer: block5_conv3
Deleting 1/504 channels from layer: block5_conv3
Deleting 1/503 channels from layer: block5_conv3
Deleting 1/502 channels fr



Epoch 2/2
\Ft tst acc: 0.5357917547225952 - ft tgt acc: 0.0810810774564743 - ft num params: 14190060

-------------------------------------------------------------------------------- 

Shapes: 
trn: (7144, 100, 100, 3) - (7144, 2)
pt: (37, 100, 100, 3) - (37, 2)
p: (144, 100, 100, 3) - (144, 2)
tst: (5993, 100, 100, 3) - (5993, 2)
ft: (61, 100, 100, 3) - (61, 2)

Orig tst acc: 0.8388119339942932 - orig tgt acc: 0.7567567825317383 - orig num params: 14715714

Neurons in layer 512
Deleting 1/512 channels from layer: block5_conv3
Deleting 1/511 channels from layer: block5_conv3
Deleting 1/510 channels from layer: block5_conv3
Deleting 1/509 channels from layer: block5_conv3
Deleting 1/508 channels from layer: block5_conv3
Deleting 1/507 channels from layer: block5_conv3
Deleting 1/506 channels from layer: block5_conv3
Deleting 1/505 channels from layer: block5_conv3
Deleting 1/504 channels from layer: block5_conv3
Deleting 1/503 channels from layer: block5_conv3
Deleting 1/502 channels fr



Epoch 2/2
\Ft tst acc: 0.5172701478004456 - ft tgt acc: 0.8918918967247009 - ft num params: 14190060

-------------------------------------------------------------------------------- 




In [7]:
res_df = pd.DataFrame.from_dict(results)

In [8]:
res_df

Unnamed: 0,repetition,orig_tst_acc,orig_tgt_acc,orig_params,num_pruned_neurons,pruned_tst_acc,pruned_tgt_acc,pruned_params,ft_tst_acc,ft_tgt_acc,ft_params,pt_in_ft
0,0,0.838478,0.756757,14715714,119,0.795428,0.486486,14167005,0.519106,0.351351,14167005,1
1,1,0.838478,0.756757,14715714,114,0.798098,0.513514,14190060,0.604205,0.351351,14190060,1
2,2,0.838144,0.756757,14715714,114,0.798098,0.513514,14190060,0.708827,0.864865,14190060,0
3,3,0.839646,0.756757,14715714,114,0.799099,0.513514,14190060,0.535792,0.081081,14190060,1
4,4,0.838812,0.756757,14715714,114,0.798432,0.513514,14190060,0.51727,0.891892,14190060,2


In [9]:
res_df.mean()

repetition            2.000000e+00
orig_tst_acc          8.387118e-01
orig_tgt_acc          7.567568e-01
orig_params           1.471571e+07
num_pruned_neurons    1.150000e+02
pruned_tst_acc        7.978308e-01
pruned_tgt_acc        5.081081e-01
pruned_params         1.418545e+07
ft_tst_acc            5.770399e-01
ft_tgt_acc            5.081081e-01
ft_params             1.418545e+07
pt_in_ft              1.000000e+00
dtype: float64