# DeepSHAP Benchmark

In [1]:
import sys
sys.path.append("../")

import fastISM
from fastISM.models.basset import basset_model

from fastISM.models.factorized_basset import factorized_basset_model
from fastISM.models.bpnet import bpnet_model
import tensorflow as tf
import numpy as np
from importlib import reload
import time

import  shap

In [2]:
# Re-execute the fastISM library modules so edits made on disk are picked up
# without restarting the kernel; submodules are reloaded before the package.
# NOTE(review): `importlib.reload` does not rebind names bound earlier with
# `from ... import ...` (e.g. `basset_model`, `bpnet_model` from the first
# cell), and reloading a package does not by itself re-execute submodules that
# are already cached (e.g. `fastISM.models.basset`) — re-run the import cell
# if model definitions changed.
reload(fastISM.flatten_model)
reload(fastISM.models)
reload(fastISM.ism_base)
reload(fastISM.change_range)
reload(fastISM.fast_ism_utils)
reload(fastISM)

<module 'fastISM' from '../fastISM/__init__.py'>

In [3]:
# Display the installed TensorFlow version alongside the benchmark results.
tf.__version__

'2.3.0'

In [4]:
# Report how many physical GPUs TensorFlow can see in this session.
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

Num GPUs Available:  1


In [5]:
# Use the first GPU when TensorFlow can see one, otherwise fall back to the CPU.
if tf.config.experimental.list_physical_devices('GPU'):
    device = 'GPU:0'
else:
    device = '/device:CPU:0'
device

'GPU:0'

In [6]:
# https://github.com/kundajelab/tfmodisco_tf_models/blob/bd449328b/src/extract/dinuc_shuffle.py

def string_to_char_array(seq):
    """
    Return the UTF-8 bytes of the string `seq` as a NumPy int8 array.

    For ASCII input every character maps to its code point, e.g. "ACGT"
    becomes [65, 67, 71, 84]. The array is a view over the encoded (immutable)
    bytes, so it is read-only.
    """
    encoded = bytes(seq, "utf8")
    return np.frombuffer(encoded, dtype=np.int8)


def char_array_to_string(arr):
    """
    Converts a NumPy array of byte-long ASCII codes into an ASCII string.
    e.g. [65, 67, 71, 84] becomes "ACGT".

    Raises UnicodeDecodeError if any byte is outside the ASCII range.
    """
    # `ndarray.tostring` is deprecated since NumPy 1.19 and removed in NumPy
    # 2.0; `tobytes` returns exactly the same raw bytes.
    return arr.tobytes().decode("ascii")


def one_hot_to_tokens(one_hot):
    """
    Converts an L x D one-hot encoding into an L-vector of integers in the range
    [0, D], where the token D is used when a row (position) of the encoding is
    all 0. This assumes that the one-hot encoding is well-formed, with at most
    one nonzero entry in each row; if a row has several, the highest nonzero
    column index wins.
    """
    seq_len, one_hot_dim = one_hot.shape
    tokens = np.full(seq_len, one_hot_dim)  # Default every position to D
    # np.where yields indices in row-major order, so within a row later
    # (higher) columns overwrite earlier ones in the assignment below.
    seq_inds, dim_inds = np.where(one_hot)
    tokens[seq_inds] = dim_inds
    return tokens


def tokens_to_one_hot(tokens, one_hot_dim):
    """
    Map an L-vector of tokens in [0, D] back to an L x D float one-hot matrix.

    `one_hot_dim` is D. Token D has no column of its own and becomes an
    all-zero row, which is the inverse of the convention in `one_hot_to_tokens`.
    """
    # (D + 1) x D lookup table: rows 0..D-1 are unit vectors, row D is zeros.
    lookup = np.eye(one_hot_dim + 1, one_hot_dim)
    return lookup[tokens]


def dinuc_shuffle(seq, num_shufs, rng=None):
    """
    Creates shuffles of the given sequence, in which dinucleotide frequencies
    are preserved.
    Arguments:
        `seq`: either a string of length L, or an L x D NumPy array of one-hot
            encodings
        `num_shufs`: the number of shuffles to create, N
        `rng`: a NumPy RandomState object, to use for performing shuffles; a
            fresh, unseeded RandomState is created when omitted
    If `seq` is a string, returns a list of N strings of length L, each one
    being a shuffled version of `seq`. If `seq` is a 2D NumPy array, then the
    result is an N x L x D NumPy array of shuffled versions of `seq`, also
    one-hot encoded. An empty `seq` yields N empty shuffles.
    Raises ValueError if `seq` is neither a string nor a 2D NumPy array.
    """
    is_string = isinstance(seq, str)
    if is_string:
        arr = string_to_char_array(seq)
    elif isinstance(seq, np.ndarray) and seq.ndim == 2:
        seq_len, one_hot_dim = seq.shape
        arr = one_hot_to_tokens(seq)
    else:
        raise ValueError("Expected string or one-hot encoded array")

    if rng is None:
        rng = np.random.RandomState()

    # An empty sequence has no first token to start the walk below from
    # (`result[0]` would raise IndexError), so every shuffle is just empty.
    if arr.size == 0:
        if is_string:
            return [""] * num_shufs
        return np.empty((num_shufs, seq_len, one_hot_dim), dtype=seq.dtype)

    # Get the set of all characters, and a mapping of which positions have which
    # characters; use `tokens`, which are integer representations of the
    # original characters
    chars, tokens = np.unique(arr, return_inverse=True)

    # For each token, get a list of indices of all the tokens that come after it
    shuf_next_inds = []
    for t in range(len(chars)):
        mask = tokens[:-1] == t  # Excluding last char
        inds = np.where(mask)[0]
        shuf_next_inds.append(inds + 1)  # Add 1 for next token

    if is_string:
        all_results = []
    else:
        all_results = np.empty(
            (num_shufs, seq_len, one_hot_dim), dtype=seq.dtype
        )

    for i in range(num_shufs):
        # Shuffle the next indices
        for t in range(len(chars)):
            inds = np.arange(len(shuf_next_inds[t]))
            # Keep the last index same: each token's final successor is left
            # as it was before this shuffle round
            inds[:-1] = rng.permutation(len(inds) - 1)
            shuf_next_inds[t] = shuf_next_inds[t][inds]

        counters = [0] * len(chars)

        # Build the resulting array by walking from the original first token,
        # consuming each token's (shuffled) successors in order
        ind = 0
        result = np.empty_like(tokens)
        result[0] = tokens[ind]
        for j in range(1, len(tokens)):
            t = tokens[ind]
            ind = shuf_next_inds[t][counters[t]]
            counters[t] += 1
            result[j] = tokens[ind]

        if is_string:
            all_results.append(char_array_to_string(chars[result]))
        else:
            all_results[i] = tokens_to_one_hot(chars[result], one_hot_dim)
    return all_results

In [7]:
# based on https://github.com/kundajelab/tfmodisco_tf_models/blob/bd449328b22/src/extract/compute_profile_shap.py

def create_background(model_inputs, bg_size=10, seed=20191206):
    """
    From the inputs for a single example, generates the set of background
    inputs to perform interpretation against.
    Arguments:
        `model_inputs`: an indexable collection of model inputs for one
            example; only the first entry is used, and it must be a single
            one-hot encoded input sequence of shape I x 4 (any further entries
            are ignored)
        `bg_size`: the number of background examples to generate, G
        `seed`: seed for the RandomState used for shuffling; because it is
            fixed by default, the same input always gets the same background
    Returns a single G x I x 4 array of background input sequences: random
    dinucleotide shuffles of the original sequence.
    """
    # NOTE(review): passed below as the `data` callable of shap.DeepExplainer,
    # which presumably calls it once per explained example — confirm against
    # the installed shap version.
    input_seq = model_inputs[0]
    rng = np.random.RandomState(seed)
    return dinuc_shuffle(input_seq, bg_size, rng=rng)

In [8]:
# Display the installed shap version; the cells below patch handlers in
# `shap.explainers._deep.deep_tf`, a private module whose layout may differ
# between shap versions.
shap.__version__

'0.36.0'

In [9]:
# Register TensorFlow's `AddV2` op with shap's DeepExplainer TF backend using
# the `passthrough` handler.
# NOTE(review): presumably needed because this shap version has no handler
# entry for the TF2 `AddV2` op — confirm against the installed shap.
shap.explainers._deep.deep_tf.op_handlers["AddV2"] = shap.explainers._deep.deep_tf.passthrough

## Benchmark

### Basset/Factorized Basset

In [15]:
# Batch sizes timed for the Basset / Factorized Basset benchmark below.
BATCH_SIZES = [1,32,64,128,256,512, 1024]

In [16]:
# shap_values most likely internally creates a batch for each example
# thus time per 100 examples stays near constant with batch size
#
# NOTE(review): the inputs below are uniform random floats, not one-hot
# sequences, so `create_background`'s dinucleotide shuffle does not see a
# well-formed encoding. That is acceptable for timing only.

for model_type in [basset_model, factorized_basset_model]:
    for seqlen in [1000, 2000]:
        print("\n------------------")
        # Print the model's name rather than the function repr, which
        # includes a meaningless memory address.
        print("MODEL: {}".format(model_type.__name__))
        print("SEQLEN: {}".format(seqlen))
        model = model_type(seqlen=seqlen, num_outputs=1)

        # Dry run so one-time setup cost (e.g. graph construction / GPU
        # warm-up) is excluded from the timings below.
        e = shap.DeepExplainer(model, data=create_background)
        e.shap_values(np.random.random((10, seqlen, 4)), check_additivity=False)

        times = []
        per_100 = []
        for b in BATCH_SIZES:
            x = np.random.random((b, seqlen, 4))
            # perf_counter is monotonic and high-resolution, unlike time.time
            t = time.perf_counter()
            e.shap_values(x, check_additivity=False)
            times.append(time.perf_counter() - t)
            per_100.append(times[-1] / b * 100)
            print("BATCH SIZE: {}\tTIME: {:.2f}\tPER 100: {:.2f}".format(b, times[-1], per_100[-1]))

        print("BEST PER 100: {:.2f}".format(min(per_100)))


------------------
MODEL: <function basset_model at 0x7fc152795b90>
SEQLEN: 1000
BATCH SIZE: 1	TIME: 0.03	PER 100: 2.56
BATCH SIZE: 32	TIME: 0.67	PER 100: 2.09
BATCH SIZE: 64	TIME: 1.19	PER 100: 1.86
BATCH SIZE: 128	TIME: 2.24	PER 100: 1.75
BATCH SIZE: 256	TIME: 4.48	PER 100: 1.75
BATCH SIZE: 512	TIME: 8.95	PER 100: 1.75
BATCH SIZE: 1024	TIME: 17.97	PER 100: 1.75
BEST PER 100: 1.75

------------------
MODEL: <function basset_model at 0x7fc152795b90>
SEQLEN: 2000
BATCH SIZE: 1	TIME: 0.03	PER 100: 3.14
BATCH SIZE: 32	TIME: 0.97	PER 100: 3.04
BATCH SIZE: 64	TIME: 1.95	PER 100: 3.05
BATCH SIZE: 128	TIME: 3.89	PER 100: 3.04
BATCH SIZE: 256	TIME: 7.79	PER 100: 3.04
BATCH SIZE: 512	TIME: 15.61	PER 100: 3.05
BATCH SIZE: 1024	TIME: 31.29	PER 100: 3.06
BEST PER 100: 3.04

------------------
MODEL: <function factorized_basset_model at 0x7fc152795c20>
SEQLEN: 1000
BATCH SIZE: 1	TIME: 0.03	PER 100: 2.69
BATCH SIZE: 32	TIME: 0.85	PER 100: 2.64
BATCH SIZE: 64	TIME: 1.69	PER 100: 2.64
BATCH SIZE: 128

### BPNet

In [10]:
# Batch sizes for the BPNet benchmark (overrides the Basset list above).
# Presumably kept small because the BPNet benchmark builds and runs one
# explainer per output position.
BATCH_SIZES = [1,8]

In [11]:
# Linear ops: register each with shap's DeepExplainer TF backend using the
# `passthrough` handler.
# NOTE(review): presumably these ops come from BPNet's dilated / transposed
# convolutions — confirm against the traced graph.
for _op in ("BatchToSpaceND", "SpaceToBatchND", "Conv2DBackpropInput"):
    shap.explainers._deep.deep_tf.op_handlers[_op] = shap.explainers._deep.deep_tf.passthrough
del _op  # don't leave the loop variable behind in the notebook namespace

In [12]:
# Creating many DeepExplainer objects (one per BPNet output position below)
# makes TensorFlow emit a flood of warnings; raise the TF v1 logging threshold
# to ERROR to suppress them.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [13]:
def _time_shap_values(explainer, seqlen, batch_sizes, times):
    """Dry-run `explainer`, then add one timed shap_values call per batch size into `times`."""
    # Dry run so one-time setup cost is excluded from the timings.
    explainer.shap_values(np.random.random((1, seqlen, 4)), check_additivity=False)
    for b_idx, b in enumerate(batch_sizes):
        x = np.random.random((b, seqlen, 4))
        # perf_counter is monotonic and high-resolution, unlike time.time
        t = time.perf_counter()
        explainer.shap_values(x, check_additivity=False)
        times[b_idx] += time.perf_counter() - t


for seqlen in [1000, 2000]:
    print("\n------------------")
    print("SEQLEN: {}".format(seqlen))
    model = bpnet_model(seqlen=seqlen, num_dilated_convs=9)

    # shap_values time per batch size, summed over every explainer below;
    # explainer construction and dry runs are not included.
    times = [0.0 for _ in BATCH_SIZES]

    # First output: one explainer per position. Making explainers is the
    # bottleneck, so each one is reused across all batch sizes.
    for i in range(seqlen):
        e = shap.DeepExplainer((model.input, model.output[0][:, i]), data=create_background)
        _time_shap_values(e, seqlen, BATCH_SIZES, times)

    # Second output (the counts output): a single explainer.
    e = shap.DeepExplainer((model.input, model.output[1]), data=create_background)
    _time_shap_values(e, seqlen, BATCH_SIZES, times)

    per_100 = [t / b * 100 for t, b in zip(times, BATCH_SIZES)]

    for b, t, p in zip(BATCH_SIZES, times, per_100):
        print("BATCH SIZE: {}\tTIME: {:.2f}\tPER 100: {:.2f}".format(b, t, p))

    print("BEST PER 100: {:.2f}".format(min(per_100)))


------------------
SEQLEN: 1000
BATCH SIZE: 1	TIME: 17.80	PER 100: 1779.57
BATCH SIZE: 8	TIME: 139.47	PER 100: 1743.35
BEST PER 100: 1743.35

------------------
SEQLEN: 2000
BATCH SIZE: 1	TIME: 65.06	PER 100: 6506.26
BATCH SIZE: 8	TIME: 514.22	PER 100: 6427.77
BEST PER 100: 6427.77
