In [2]:
import sys
sys.path.insert(0, '../../../fastshap_tf/')
from surrogate import ImageSurrogate

In [3]:
import pickle
import numpy as np
import shap
from tqdm.notebook import tqdm
import time

In [4]:
import tensorflow as tf
import tensorflow_datasets as tfds

from datetime import datetime
import os

In [5]:
from tensorflow.keras.layers import (Input, Layer, Dense)
from tensorflow.keras.models import Model

In [6]:
# IMPORTANT: SET RANDOM SEEDS FOR REPRODUCIBILITY
SEED = 420

# NOTE: PYTHONHASHSEED is only read at interpreter start-up, so assigning it
# here does not change hash randomization in this already-running kernel; it
# only propagates to child processes. Export it before launching Jupyter if
# hash-order determinism matters.
os.environ['PYTHONHASHSEED'] = str(SEED)
import random
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [7]:
# Select GPU: restrict CUDA to device 0.
# NOTE(review): CUDA reads this variable when the GPU runtime is first
# initialized. It is assigned here after `import tensorflow`, so confirm no GPU
# work has run yet in this kernel, otherwise the setting is silently ignored.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

## Load Data

In [8]:
# Imagenette (full-size v2): keep the official train split and cut the official
# validation split in half -- first half for validation, second half for test.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    name='imagenette/full-size-v2',
    split=['train', 'validation[:50%]', 'validation[-50%:]'],
    with_info=True,
    as_supervised=False,
)

### Batch Data

In [9]:
def batch_data(dataset, fn, batch_size=32):
    """Map, batch and prefetch a ``tf.data`` pipeline.

    Args:
        dataset: a ``tf.data.Dataset`` of raw examples.
        fn: per-example transform applied with ``Dataset.map``.
        batch_size: number of examples per batch (the last batch may be
            smaller, since ``drop_remainder`` is left at its default).

    Returns:
        The transformed, batched and prefetched dataset.
    """
    autotune = tf.data.experimental.AUTOTUNE
    # Parallel map overlaps the per-example preprocessing; tf.data keeps the
    # output order deterministic by default, so results are unchanged.
    dataset = dataset.map(fn, num_parallel_calls=autotune)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(autotune)

    return dataset

### Reformat Data

In [10]:
BATCH_SIZE = 32

def reformat(input_dict, image_size=224, num_classes=10):
    """Convert a raw Imagenette example into an ``(image, one_hot_label)`` pair.

    Args:
        input_dict: example dict with ``'image'`` and ``'label'`` entries
            (the format produced by ``tfds.load(..., as_supervised=False)``).
        image_size: side length of the square center crop / pad. Defaults to 224.
        num_classes: depth of the one-hot label vector. Defaults to 10.

    Returns:
        Tuple ``(image, label)``: the cropped/padded image after ResNet50
        ``preprocess_input``, and the one-hot encoded label.
    """
    image = tf.cast(input_dict['image'], tf.float32)
    # Center crop (or zero-pad) to a fixed size. No rescaling is applied, so the
    # visible field of view depends on each image's original resolution.
    image = tf.image.resize_with_crop_or_pad(image, image_size, image_size)
    image = tf.keras.applications.resnet50.preprocess_input(image)

    label = tf.one_hot(input_dict['label'], depth=num_classes)

    return (image, label)

# Apply the same map -> batch -> prefetch pipeline to every split, in order.
ds_train, ds_val, ds_test = [
    batch_data(split, reformat, BATCH_SIZE)
    for split in (ds_train, ds_val, ds_test)
]

## Load Model

In [11]:
from tensorflow.keras.applications.resnet50 import ResNet50

INPUT_SHAPE = (224, 224, 3)

# --- Frozen backbone: ResNet50 with ImageNet weights, keeping its own top ---
base_model = ResNet50(include_top=True,
                      weights='imagenet',
                      input_shape=INPUT_SHAPE)
base_model.trainable = False

# --- Classifier: a 10-way softmax Dense head stacked on the backbone output ---
model_input = Input(shape=INPUT_SHAPE, dtype='float32', name='input')
net = base_model(model_input)
out = Dense(10, activation='softmax')(net)
model = Model(inputs=model_input, outputs=out)

# --- Restore previously trained weights and freeze the whole model ---
# NOTE(review): the layer graph above must match the one the checkpoint was
# saved from; do not reorder or add layers before this load.
model_weights_path = 'model/20210511_21_28_36/model_weights.h5'
model.load_weights(model_weights_path)
model.trainable = False

## Train Surrogate

### Save Dir

In [12]:
# Timestamped output directory for this run: <cwd>/surrogate/<YYYYmmdd_HH_MM_SS>.
date = datetime.now().strftime("%Y%m%d_%H_%M_%S")
save_dir = 'surrogate'
model_dir = os.path.join(os.getcwd(), save_dir, date)
# exist_ok=True replaces the separate isdir() check and avoids a
# check-then-create race; it still raises if a non-directory is in the way.
os.makedirs(model_dir, exist_ok=True)

### Train

In [33]:
surrogate = ImageSurrogate(model, model_dir)

# perf_counter() is monotonic, so the measured duration is unaffected by
# system clock adjustments during a long training run.
t = time.perf_counter()
surrogate.train(train_data=ds_train,
                val_data=ds_val,
                max_epochs=100,
                batch_size=BATCH_SIZE,  # same constant used to batch ds_train / ds_val
                lookback=10,
                lr=1e-2)
training_time = time.perf_counter() - t

# Persist the elapsed training time (seconds) next to the surrogate outputs.
with open(os.path.join(model_dir, 'training_time.pkl'), 'wb') as f:
    pickle.dump(training_time, f)



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.





To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.





To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.





To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



Epoch 1/100
Epoch 00001: val_loss improved from inf to 2.30642, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/surrogate/20210511_21_47_45/value_weights.h5
Epoch 2/100
Epoch 00002: val_loss did not improve from 2.30642
Epoch 3/100
Epoch 00003: val_loss did not improve from 2.30642
Epoch 4/100
Epoch 00004: val_loss did not improve from 2.30642

Epoch 00004: ReduceLROnPlateau reducing learning rate to 0.008999999798834325.
Epoch 5/100
Epoch 00005: val_loss improved from 2.30642 to 2.11480, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/surrogate/20210511_21_47_45/value_weights.h5
Epoch 6/100
Epoch 00006: val_loss did not improve from 2.11480
Epoch 7/100
Epoch 00007: val_loss improved from 2.11480 to 1.83036, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/surrogate/20210511_21_47_45/value_weights.h5
Epoch 8/100
Epoch 00008: val_loss did not improve from 1.83036
Epoch 9/100
Epoch 00009: val_l

Epoch 26/100
Epoch 00026: val_loss did not improve from 1.49540
Epoch 27/100
Epoch 00027: val_loss improved from 1.49540 to 1.44669, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/surrogate/20210511_21_47_45/value_weights.h5
Epoch 28/100
Epoch 00028: val_loss improved from 1.44669 to 1.44219, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/surrogate/20210511_21_47_45/value_weights.h5
Epoch 29/100
Epoch 00029: val_loss did not improve from 1.44219
Epoch 30/100
Epoch 00030: val_loss did not improve from 1.44219
Epoch 31/100
Epoch 00031: val_loss did not improve from 1.44219

Epoch 00031: ReduceLROnPlateau reducing learning rate to 0.00531440949998796.
Epoch 32/100
Epoch 00032: val_loss improved from 1.44219 to 1.40044, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/surrogate/20210511_21_47_45/value_weights.h5
Epoch 33/100
Epoch 00033: val_loss did not improve from 1.40044
Epoch 34/100
Epoch 

Epoch 51/100
Epoch 00051: val_loss did not improve from 1.26103
Epoch 52/100
Epoch 00052: val_loss did not improve from 1.26103
Epoch 53/100
Epoch 00053: val_loss did not improve from 1.26103

Epoch 00053: ReduceLROnPlateau reducing learning rate to 0.003138105757534504.
Epoch 54/100
Epoch 00054: val_loss did not improve from 1.26103
Epoch 55/100
Epoch 00055: val_loss improved from 1.26103 to 1.24195, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/surrogate/20210511_21_47_45/value_weights.h5
Epoch 56/100
Epoch 00056: val_loss improved from 1.24195 to 1.20237, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/surrogate/20210511_21_47_45/value_weights.h5
Epoch 57/100
Epoch 00057: val_loss improved from 1.20237 to 1.19204, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/surrogate/20210511_21_47_45/value_weights.h5
Epoch 58/100
Epoch 00058: val_loss did not improve from 1.19204
Epoch 59/100
Epoch

Epoch 76/100
Epoch 00076: val_loss did not improve from 1.12409
Epoch 77/100
Epoch 00077: val_loss did not improve from 1.12409
Epoch 78/100
Epoch 00078: val_loss did not improve from 1.12409

Epoch 00078: ReduceLROnPlateau reducing learning rate to 0.0018530200235545636.
Epoch 79/100
Epoch 00079: val_loss did not improve from 1.12409
Epoch 80/100
Epoch 00080: val_loss did not improve from 1.12409
Epoch 81/100
Epoch 00081: val_loss did not improve from 1.12409

Epoch 00081: ReduceLROnPlateau reducing learning rate to 0.0016677180421538651.
Epoch 82/100
Epoch 00082: val_loss improved from 1.12409 to 1.10183, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/surrogate/20210511_21_47_45/value_weights.h5
Epoch 83/100
Epoch 00083: val_loss did not improve from 1.10183
Epoch 84/100
Epoch 00084: val_loss did not improve from 1.10183
Epoch 85/100
Epoch 00085: val_loss did not improve from 1.10183

Epoch 00085: ReduceLROnPlateau reducing learning rate to 0.00150094