In [1]:
import sys
sys.path.insert(0, '../../../fastshap_tf/')
from fastshap import ImageFastSHAP
from utils import ShapleySampler, ResizeMask

RuntimeError: module compiled against API version 0xe but this version of numpy is 0xd

In [2]:
import pickle
import numpy as np
import shap
from tqdm.notebook import tqdm
import time

In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds

from datetime import datetime
import os

In [4]:
from tensorflow.keras.layers import (Input, Layer, Dense, Lambda, Reshape, Multiply)
from tensorflow.keras.models import Model, Sequential

In [5]:
# Select GPU: expose only CUDA device index 0 to this process.
GPU_ID = '0'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID

In [6]:
# IMPORTANT: seed every random number generator used in this notebook
# (Python `random`, NumPy, TensorFlow) with the same value for reproducibility.
import random

SEED = 420

# NOTE(review): PYTHONHASHSEED is normally read when the interpreter starts, so
# setting it here probably only affects child processes -- confirm if hash
# determinism of this kernel matters.
os.environ['PYTHONHASHSEED'] = str(SEED)

for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    seed_fn(SEED)

## Load Data

In [7]:
# Load Imagenette (full-size v2). The validation split is cut in half:
# the first 50% is used for validation, the last 50% is held out as the test set.
# Examples are returned as feature dicts (as_supervised=False).
SPLITS = ['train', 'validation[:50%]', 'validation[-50%:]']

datasets, ds_info = tfds.load(
    'imagenette/full-size-v2',
    split=SPLITS,
    as_supervised=False,
    with_info=True,
)
ds_train, ds_val, ds_test = datasets

### Batch Data

In [8]:
def batch_data(dataset, fn, batch_size=32):
    """Map ``fn`` over ``dataset``, then batch and prefetch the result.

    Args:
        dataset: Object exposing ``map``/``batch``/``prefetch``
            (a ``tf.data.Dataset`` in this notebook).
        fn: Per-element function passed to ``dataset.map``.
        batch_size: Number of elements per batch (default 32).

    Returns:
        The mapped and batched dataset, prefetching with an autotuned buffer.
    """
    return (
        dataset
        .map(fn)
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

### Reformat Data

In [9]:
BATCH_SIZE = 32

def reformat(input_dict):
    """Turn one dataset example into an ``(image, label)`` training pair.

    Args:
        input_dict: Feature dict with an ``'image'`` and a ``'label'`` entry.

    Returns:
        Tuple ``(image, label)``: the image cast to float32, center
        cropped/padded to 224x224 and passed through the ResNet50
        ``preprocess_input``; the label one-hot encoded with depth 10.
    """
    image = tf.cast(input_dict['image'], tf.float32)
    image = tf.image.resize_with_crop_or_pad(image, 224, 224)
    image = tf.keras.applications.resnet50.preprocess_input(image)

    label = tf.one_hot(input_dict['label'], depth=10)

    return (image, label)

# Apply the same preprocessing + batching pipeline to all three splits.
# NOTE: this rebinds ds_train/ds_val/ds_test, so the cell is not safe to re-run
# without re-running the data-loading cell first.
ds_train, ds_val, ds_test = (
    batch_data(split, reformat, BATCH_SIZE)
    for split in (ds_train, ds_val, ds_test)
)

## Load Surrogate Imputer

In [10]:
from tensorflow.keras.applications.resnet50 import ResNet50

# Rebuild the surrogate architecture (masking front-end + ResNet50 + softmax head)
# so that the checkpointed weights can be loaded into it; the masking front-end
# is dropped again once the weights are in place.
input_shape = (224,224,3)
# Mask length handed to ShapleySampler / ResizeMask below
# (presumably a 14x14 grid of image patches -- TODO confirm against utils).
P = 14*14
# ImageNet-pretrained ResNet50 without its classification head, average-pooled.
value_model = ResNet50(
    include_top=False, weights='imagenet', 
    input_shape=input_shape, pooling='avg'
) 
# Width of the softmax head; equals the one-hot depth (10) used for the labels.
D = 10

# Masking front-end: a mask produced by ShapleySampler is cast to float32,
# reshaped to (P,), passed through ResizeMask and multiplied into the input.
# NOTE(review): the float64 input dtype looks like the source of the Keras
# autocast warnings printed below this cell -- confirm before changing it.
model_input = Input(shape=input_shape, dtype='float64', name='input')
S = ShapleySampler(P, paired_sampling=False, num_samples=1)(model_input)
S = Lambda(lambda x: tf.cast(x, tf.float32))(S)
S = Reshape((P,))(S)
S = ResizeMask(in_shape=input_shape, mask_size=P)(S)
xs = Multiply()([model_input, S])

# Backbone features of the masked image -> D-way softmax.
net = value_model(xs)
out = Dense(D, activation='softmax')(net)

surrogate = Model(model_input, out)

# Get Checkpointed Model
weights_path = 'surrogate/20210511_21_47_45/value_weights.h5'
surrogate.load_weights(weights_path)

# Remove Masking Layer: keep only the last two layers of the model built above
# (the ResNet50 backbone and the Dense softmax head, given the construction order).
surrogate = Sequential(   
    [l for l in surrogate.layers[-2:]]
)
# Freeze the surrogate so its weights are not updated from here on.
surrogate.trainable = False



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.





To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.





To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.





To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



# Train FastSHAP

### Save Dir

In [11]:
# Timestamped output directory for this run: <cwd>/fastshap/<YYYYmmdd_HH_MM_SS>.
date = datetime.now().strftime("%Y%m%d_%H_%M_%S")
save_dir = 'fastshap'
model_dir = os.path.join(os.getcwd(), save_dir, date)
# exist_ok=True replaces the isdir-then-makedirs check: it is idempotent on
# re-run and has no window in which another process can create the directory
# between the check and the call.
os.makedirs(model_dir, exist_ok=True)

### Initialize

In [12]:
from importlib import reload
import fastshap
import utils

# Re-import the local modules so edits made on disk are picked up without
# restarting the kernel, then refresh the class name bound in this notebook.
for _module in (fastshap, utils):
    reload(_module)
from fastshap import ImageFastSHAP

In [13]:
# Build the FastSHAP trainer around the frozen surrogate; outputs go to model_dir.
# NOTE: this rebinds the name `fastshap` from the imported module to the
# ImageFastSHAP instance used in the cells below.
fastshap = ImageFastSHAP(
    imputer=surrogate,
    normalization=None,
    model_dir=model_dir,
    link='logit',
)

### Train

In [14]:
# Train the explainer and record how long training takes.
# perf_counter() is monotonic, so the measured duration cannot be distorted by
# system clock adjustments during a long training run.
t = time.perf_counter()
fastshap.train(
    train_data=ds_train,
    val_data=ds_val,
    max_epochs=100,
    # Same constant that was used to batch ds_train / ds_val, so the two
    # batch sizes cannot drift apart.
    batch_size=BATCH_SIZE,
    num_samples=1,
    lr=1e-3,
    paired_sampling=True,
    eff_lambda=0.0,
    verbose=1,
    lookback=20,
)
training_time = time.perf_counter() - t

# Persist the elapsed training time (seconds, as a float) in the run directory.
with open(os.path.join(model_dir, 'training_time.pkl'), 'wb') as f:
    pickle.dump(training_time, f)

The following Variables were used a Lambda layer's call (lambda_1), but
are not present in its tracked objects:
  <tf.Variable 'conv1_conv/kernel:0' shape=(7, 7, 3, 64) dtype=float32>
  <tf.Variable 'conv1_conv/bias:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv1_bn/gamma:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv1_bn/beta:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv2_block1_0_conv/kernel:0' shape=(1, 1, 64, 256) dtype=float32>
  <tf.Variable 'conv2_block1_0_conv/bias:0' shape=(256,) dtype=float32>
  <tf.Variable 'conv2_block1_0_bn/gamma:0' shape=(256,) dtype=float32>
  <tf.Variable 'conv2_block1_0_bn/beta:0' shape=(256,) dtype=float32>
  <tf.Variable 'conv2_block1_1_conv/kernel:0' shape=(1, 1, 64, 64) dtype=float32>
  <tf.Variable 'conv2_block1_1_conv/bias:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv2_block1_1_bn/gamma:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv2_block1_1_bn/beta:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv2_block1_2_conv/kernel:0'

The following Variables were used a Lambda layer's call (lambda_1), but
are not present in its tracked objects:
  <tf.Variable 'conv1_conv/kernel:0' shape=(7, 7, 3, 64) dtype=float32>
  <tf.Variable 'conv1_conv/bias:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv1_bn/gamma:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv1_bn/beta:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv2_block1_0_conv/kernel:0' shape=(1, 1, 64, 256) dtype=float32>
  <tf.Variable 'conv2_block1_0_conv/bias:0' shape=(256,) dtype=float32>
  <tf.Variable 'conv2_block1_0_bn/gamma:0' shape=(256,) dtype=float32>
  <tf.Variable 'conv2_block1_0_bn/beta:0' shape=(256,) dtype=float32>
  <tf.Variable 'conv2_block1_1_conv/kernel:0' shape=(1, 1, 64, 64) dtype=float32>
  <tf.Variable 'conv2_block1_1_conv/bias:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv2_block1_1_bn/gamma:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv2_block1_1_bn/beta:0' shape=(64,) dtype=float32>
  <tf.Variable 'conv2_block1_2_conv/kernel:0'

Epoch 1/100
Epoch 00001: val_shap_loss improved from inf to 24.42870, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/fastshap/20210519_16_09_16/explainer_weights.h5
Epoch 2/100
Epoch 00002: val_shap_loss improved from 24.42870 to 22.15116, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/fastshap/20210519_16_09_16/explainer_weights.h5
Epoch 3/100
Epoch 00003: val_shap_loss did not improve from 22.15116
Epoch 4/100
Epoch 00004: val_shap_loss improved from 22.15116 to 17.93956, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/fastshap/20210519_16_09_16/explainer_weights.h5
Epoch 5/100
Epoch 00005: val_shap_loss did not improve from 17.93956
Epoch 6/100
Epoch 00006: val_shap_loss improved from 17.93956 to 17.54804, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/fastshap/20210519_16_09_16/explainer_weights.h5
Epoch 7/100
Epoch 00007: val_shap_loss improved from 17

Epoch 21/100
Epoch 00021: val_shap_loss did not improve from 17.15061
Epoch 22/100
Epoch 00022: val_shap_loss did not improve from 17.15061

Epoch 00022: ReduceLROnPlateau reducing learning rate to 0.00032768002711236477.
Epoch 23/100
Epoch 00023: val_shap_loss did not improve from 17.15061
Epoch 24/100
Epoch 00024: val_shap_loss improved from 17.15061 to 16.74248, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/fastshap/20210519_16_09_16/explainer_weights.h5
Epoch 25/100
Epoch 00025: val_shap_loss improved from 16.74248 to 16.50909, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/fastshap/20210519_16_09_16/explainer_weights.h5
Epoch 26/100
Epoch 00026: val_shap_loss did not improve from 16.50909
Epoch 27/100
Epoch 00027: val_shap_loss did not improve from 16.50909
Epoch 28/100
Epoch 00028: val_shap_loss did not improve from 16.50909

Epoch 00028: ReduceLROnPlateau reducing learning rate to 0.0002621440216898918.
Epoch 29/

Epoch 41/100
Epoch 00041: val_shap_loss improved from 15.59433 to 15.07786, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/fastshap/20210519_16_09_16/explainer_weights.h5
Epoch 42/100
Epoch 00042: val_shap_loss improved from 15.07786 to 15.06184, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/fastshap/20210519_16_09_16/explainer_weights.h5
Epoch 43/100
Epoch 00043: val_shap_loss did not improve from 15.06184
Epoch 44/100
Epoch 00044: val_shap_loss did not improve from 15.06184
Epoch 45/100
Epoch 00045: val_shap_loss did not improve from 15.06184

Epoch 00045: ReduceLROnPlateau reducing learning rate to 0.00010737419361248613.
Epoch 46/100
Epoch 00046: val_shap_loss did not improve from 15.06184
Epoch 47/100
Epoch 00047: val_shap_loss did not improve from 15.06184
Epoch 48/100
Epoch 00048: val_shap_loss did not improve from 15.06184

Epoch 00048: ReduceLROnPlateau reducing learning rate to 8.589935605414213e-05.
Epoch 49/

Epoch 61/100
Epoch 00061: val_shap_loss did not improve from 14.80786

Epoch 00061: ReduceLROnPlateau reducing learning rate to 4.398046876303852e-05.
Epoch 62/100
Epoch 00062: val_shap_loss did not improve from 14.80786
Epoch 63/100
Epoch 00063: val_shap_loss did not improve from 14.80786
Epoch 64/100
Epoch 00064: val_shap_loss did not improve from 14.80786

Epoch 00064: ReduceLROnPlateau reducing learning rate to 3.518437442835421e-05.
Epoch 65/100
Epoch 00065: val_shap_loss did not improve from 14.80786
Epoch 66/100
Epoch 00066: val_shap_loss did not improve from 14.80786
Epoch 67/100
Epoch 00067: val_shap_loss did not improve from 14.80786

Epoch 00067: ReduceLROnPlateau reducing learning rate to 2.8147498960606756e-05.
Epoch 68/100
Epoch 00068: val_shap_loss did not improve from 14.80786
Epoch 69/100
Epoch 00069: val_shap_loss did not improve from 14.80786
Epoch 70/100
Epoch 00070: val_shap_loss did not improve from 14.80786

Epoch 00070: ReduceLROnPlateau reducing learning rate t

# Explain with FastSHAP

### Load Images

In [15]:
# Load the preprocessed images to explain from <cwd>/images/processed_images.npy.
# allow_pickle=True lets np.load unpickle object arrays, so only use it on
# files you trust.
images_dir = os.path.join(os.getcwd(), 'images')
images_path = os.path.join(images_dir, 'processed_images.npy')
images = np.load(images_path, allow_pickle=True)

### Explain

In [18]:
# Time only the explainer's forward pass over the loaded images.
t = time.time()
raw_shap_values = fastshap.explainer.predict(images)
explaining_time = time.time() - t

# Split the last axis of the prediction into a list with one array per index 0..9.
shap_values = [
    raw_shap_values[:, :, :, class_idx]
    for class_idx in range(10)
]

### Save

In [19]:
# Save the explanation timing and the per-class SHAP values into model_dir,
# the same run directory that already holds training_time.pkl.
# exist_ok=True makes the directory creation idempotent and race-free.
os.makedirs(model_dir, exist_ok=True)

with open(os.path.join(model_dir, 'explaining_time.pkl'), 'wb') as f:
    pickle.dump(explaining_time, f)

with open(os.path.join(model_dir, 'shap_values.pkl'), 'wb') as f:
    pickle.dump(shap_values, f)