In [2]:
import numpy as np
import sys

import tensorflow as tf
from tensorflow.keras.layers import (Input, Layer, Dense, Lambda, 
                                     Dropout, Multiply, BatchNormalization, 
                                     Reshape, Concatenate, Conv2D, Permute)
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras import regularizers
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

import tensorflow_datasets as tfds

from datetime import datetime
import os

# Select GPU: expose only CUDA device '0' to this process.
# NOTE(review): this assignment runs after `import tensorflow` above; confirm
# that TensorFlow has not already enumerated its GPUs by this point.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [4]:
# IMPORTANT: SET RANDOM SEEDS FOR REPRODUCIBILITY
import random

# Single seed shared by every random number generator used in this notebook.
SEED = 420

# NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so this
# assignment presumably only affects child processes, not the running kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)

random.seed(SEED)         # Python's built-in RNG
np.random.seed(SEED)      # NumPy's legacy global RNG
tf.random.set_seed(SEED)  # TensorFlow's global seed

# Train Model to Be Explained

### Parameters

In [5]:
# Training configuration used by the cells below.
BATCH_SIZE = 32              # examples per batch handed to batch_data()
EPOCHS = 50                  # maximum epochs; early stopping may end training sooner
LR = 0.01                    # initial Adam learning rate (lowered on plateau)
INPUT_SHAPE = (224, 224, 3)  # input shape given to ResNet50 and the model Input

### Load Data

In [6]:
# Load Imagenette (full-size v2) through TensorFlow Datasets.
# The 'validation' split is cut in two: its first 50% is used for validation
# and its last 50% is held out as the test set.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    name='imagenette/full-size-v2',
    split=['train', 'validation[:50%]', 'validation[-50%:]'],
    as_supervised=False,  # reformat() below indexes each example by 'image' / 'label'
    with_info=True,
)

### Batch Data

In [7]:
def batch_data(dataset, fn, batch_size=32):
    """Apply a per-example transform, then batch and prefetch a dataset.

    Args:
        dataset: Object exposing ``map``, ``batch`` and ``prefetch``
            (a ``tf.data.Dataset`` in this notebook).
        fn: Callable handed to ``dataset.map`` and applied to each example.
        batch_size: Number of consecutive examples grouped into one batch.

    Returns:
        The mapped, batched and prefetching dataset.
    """
    autotune = tf.data.experimental.AUTOTUNE

    # Run the per-example transform in parallel instead of one element at a
    # time. tf.data keeps map() deterministic by default, so the order of the
    # produced elements is the same as with a sequential map.
    dataset = dataset.map(fn, num_parallel_calls=autotune)
    dataset = dataset.batch(batch_size)
    # Overlap input preparation with model execution.
    dataset = dataset.prefetch(autotune)

    return dataset

### Reformat Data

In [8]:
def reformat(input_dict):
    """Turn one dataset example into an ``(image, label)`` training pair.

    Args:
        input_dict: Mapping with an ``'image'`` entry and a ``'label'`` entry.

    Returns:
        Tuple ``(image, label)``: the image cast to float32, cropped or padded
        to 224 x 224 and passed through the ResNet50 ``preprocess_input``
        function; the label one-hot encoded with depth 10.
    """
    image = tf.cast(input_dict['image'], tf.float32)
    # Crop or pad to the target size.
    image = tf.image.resize_with_crop_or_pad(image, 224, 224)
    image = tf.keras.applications.resnet50.preprocess_input(image)

    label = tf.one_hot(input_dict['label'], depth=10)

    return image, label

# Preprocess and batch the three splits with the same pipeline, in the order
# train, validation, test.
# NOTE: this rebinds ds_train / ds_val / ds_test, so re-running this cell
# without re-running the load cell would batch already-batched datasets.
ds_train, ds_val, ds_test = (
    batch_data(split, reformat, BATCH_SIZE)
    for split in (ds_train, ds_val, ds_test)
)

### Model

In [9]:
from tensorflow.keras.applications.resnet50 import ResNet50

# Backbone: ImageNet-pretrained ResNet50, frozen so that only the new head is trained.
# NOTE(review): include_top=True keeps ResNet50's own top classification layer,
# so the Dense head below is fed that layer's output rather than pooled
# convolutional features -- confirm this is intended before changing it, since
# any saved weights depend on this architecture.
base_model = ResNet50(
    include_top=True, weights='imagenet',
    input_shape=INPUT_SHAPE
)
base_model.trainable = False

model_input = Input(shape=INPUT_SHAPE, dtype='float32', name='input')

# 10-way softmax head stacked on the frozen backbone.
net = base_model(model_input)
out = Dense(10, activation='softmax')(net)

model = Model(model_input, out)

# Metrics
METRICS = [
  tf.keras.metrics.AUC(name='auroc'),
  tf.keras.metrics.AUC(curve='PR', name='auprc'),
  tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='accuracy'),
]

# Model Checkpointing
# Each run writes into its own timestamped directory under ./model.
# NOTE: `time` is a timestamp string here, not the stdlib `time` module.
time = datetime.now().strftime("%Y%m%d_%H_%M_%S")
save_dir = 'model'
model_dir = os.path.join(os.getcwd(), save_dir, time)
# exist_ok=True replaces the separate isdir() check, avoiding the
# check-then-create race while keeping the same outcome.
os.makedirs(model_dir, exist_ok=True)
model_weights_path = os.path.join(model_dir, 'model_weights.h5')
# Keep only the weights of the epoch with the lowest validation loss.
checkpoint = ModelCheckpoint(model_weights_path, monitor='val_loss', verbose=1,
                             save_best_only=True, mode='min', save_weights_only=True)

# LR Schedule
# Multiply the learning rate by 0.95 after 3 epochs without val_loss improvement,
# never going below 1e-4.
reduceLR = ReduceLROnPlateau(monitor='val_loss', factor=0.95, patience=3,
                             verbose=1, mode='min', cooldown=1, min_lr=1e-4)

# Early Stopping
# Stop once val_loss has not improved for 10 epochs.
earlyStop = EarlyStopping(monitor="val_loss", mode="min", patience=10)

# Compile Model
CALLBACKS = [checkpoint, earlyStop, reduceLR]
OPTIMIZER = tf.keras.optimizers.Adam(LR)

model.compile(
    loss='categorical_crossentropy',
    optimizer=OPTIMIZER,
    metrics=METRICS,
)

# Train Model
model.fit(ds_train,
          epochs = EPOCHS,
          validation_data = ds_val,
          callbacks = CALLBACKS)

# Get Checkpointed Model
# Reload the best checkpoint so evaluation uses the lowest-val_loss weights
# rather than the weights of the final epoch.
print(model_weights_path)
model.load_weights(model_weights_path)
model.trainable = False

# Evaluate
# Returns [loss, auroc, auprc, accuracy] on the held-out test split.
model.evaluate(ds_test)

Epoch 1/50
Epoch 00001: val_loss improved from inf to 0.39218, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/model/20210511_21_28_36/model_weights.h5
Epoch 2/50
Epoch 00002: val_loss improved from 0.39218 to 0.20246, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/model/20210511_21_28_36/model_weights.h5
Epoch 3/50
Epoch 00003: val_loss improved from 0.20246 to 0.14931, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/model/20210511_21_28_36/model_weights.h5
Epoch 4/50
Epoch 00004: val_loss improved from 0.14931 to 0.12468, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/model/20210511_21_28_36/model_weights.h5
Epoch 5/50
Epoch 00005: val_loss improved from 0.12468 to 0.11077, saving model to /gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/model/20210511_21_28_36/model_weights.h5
Epoch 6/50
Epoch 00006: val_loss improved from 0.11077 to 0.1020

Epoch 18/50
Epoch 00018: val_loss did not improve from 0.08587
Epoch 19/50
Epoch 00019: val_loss did not improve from 0.08587

Epoch 00019: ReduceLROnPlateau reducing learning rate to 0.009024999709799886.
Epoch 20/50
Epoch 00020: val_loss did not improve from 0.08587
Epoch 21/50
Epoch 00021: val_loss did not improve from 0.08587
Epoch 22/50
Epoch 00022: val_loss did not improve from 0.08587

Epoch 00022: ReduceLROnPlateau reducing learning rate to 0.008573750033974648.
Epoch 23/50
Epoch 00023: val_loss did not improve from 0.08587
Epoch 24/50
Epoch 00024: val_loss did not improve from 0.08587
/gpfs/data/paulab/nj594/fast_shap/experiments/images/imagenette/model/20210511_21_28_36/model_weights.h5


[0.0980679988861084,
 0.9989628791809082,
 0.9955040812492371,
 0.9714577198028564]