# Optimise ISM for the Basset Architecture

Proof of concept in tf/Keras, essentially a copy of the PyTorch notebook. 

One key difference is that this notebook forward-propagates a batch of multiple examples, each mutated in the same (i-th) region. This differs from the PyTorch notebook, in which each forward pass uses the same sequence with mutations at different positions, which made the bookkeeping harder. Inspired by discussions with Av.

Architecture (PyTorch): [link](https://github.com/kundajelab/GenoPyT/blob/c84f38dfaa0c986f91383dd7e6278c1cb993498d/src/models/sequence_only/basset.py). The padding of the last layer was changed, since Keras only allows 'same' or 'valid' padding.

In [20]:
import tensorflow as tf
import numpy as np

from collections import Counter
from copy import deepcopy

In [5]:
tf.__version__

'2.3.0'

In [8]:
def get_idxs_conv_maxpool(seqlen, kernelsize, padding, maxpool_kernel, change_ranges,
                          conv_stride=1,
                          maxpool_stride=None,
                          maxpool_ceil_mode=False):
    """Find the input slices needed to recompute a conv + maxpool block around changes.

    Models a stride-1 Conv1D applied to an input with `padding` zeros on each
    side, followed by a MaxPool1D whose stride equals its kernel (floor mode).
    For each changed input range it returns the slice of the PADDED input that
    must be fed through the block, and the maxpool output positions that slice
    produces. All windows get the same width so they can be batched together.

    Args:
        seqlen: length of the unpadded input sequence.
        kernelsize: conv kernel width.
        padding: number of zeros added on EACH side of the input.
        maxpool_kernel: maxpool window width (also used as its stride).
        change_ranges: list of half-open (start, end) ranges of changed
            positions, in UNPADDED coordinates.
        conv_stride, maxpool_stride, maxpool_ceil_mode: only the defaults
            (1, None, False) are supported.

    Returns:
        ((slice_ranges, offsets), out_ranges) where
          - slice_ranges[i] is a half-open (start, end) slice of the PADDED
            input to feed through conv + maxpool;
          - offsets[i] is where change_ranges[i][0] falls inside that slice;
          - out_ranges[i] is the half-open range of maxpool output positions
            produced by that slice.
        All three lists are empty when `change_ranges` is empty.

    Raises:
        NotImplementedError: for non-default stride / ceil-mode settings.
        ValueError: if the equal-width windows cannot all be placed inside the
            conv output. This can happen with several ranges of very different
            widths near a sequence edge.
    """
    if maxpool_ceil_mode or conv_stride != 1 or maxpool_stride is not None:
        # would need extra care, e.g. repeated values in the last block
        raise NotImplementedError(
            "only conv_stride=1, maxpool_stride=None (stride = kernel) and "
            "maxpool_ceil_mode=False are supported")

    if not change_ranges:
        return ([], []), []

    padded_len = seqlen + 2*padding

    # raw ranges for each change_range: the input range in which a change
    # to change_range can affect the conv output (unpadded coordinates)
    raw_seq_ranges = [(x-kernelsize+1, y+kernelsize-1) for x, y in change_ranges]

    # re-adjust since there will be `padding` zeros at the beginning
    raw_seq_pad_adjusted = [(x+padding, y+padding) for x, y in raw_seq_ranges]

    range_corrected = []
    for x, y in raw_seq_pad_adjusted:
        # shift windows that spill over an edge back inside, keeping their width
        if x < 0 and y > padded_len:
            # degenerate: window covers everything (needed when treating an
            # FC layer as a conv)
            range_corrected.append((0, padded_len))
        elif x < 0:
            range_corrected.append((0, y-x))
        elif y > padded_len:
            range_corrected.append((x-(y-padded_len), padded_len))
        else:
            range_corrected.append((x, y))

    # conv output positions ('valid' conv over the padded input) per range
    conv_out_ranges = [(x, y-kernelsize+1) for x, y in range_corrected]

    # length of the sequence after convolution
    conv_seqlen = padded_len - kernelsize + 1

    # expand to the edges of the enclosing maxpool blocks
    mod_shifted = [(maxpool_kernel*(x//maxpool_kernel),
                    maxpool_kernel*((y-1)//maxpool_kernel+1)) for x, y in conv_out_ranges]

    # every window must have the same width so the slices can be batched
    maxwidth = max(y-x for x, y in mod_shifted)
    mod_shifted = [(x, x+maxwidth) if y <= conv_seqlen else (y-maxwidth, y)
                   for x, y in mod_shifted]

    # with ceil_mode=False maxpool drops the trailing partial block, so a
    # window ending past conv_seqlen is moved back by one block
    mod_shifted = [(x, y) if y <= conv_seqlen else (x-maxpool_kernel, y-maxpool_kernel)
                   for x, y in mod_shifted]

    # Every window must lie inside the conv output: an end past conv_seqlen
    # has no matching maxpool output, and a negative start would silently
    # wrap around when used as a slice index.
    bad_windows = [(x, y) for x, y in mod_shifted if x < 0 or y > conv_seqlen]
    if bad_windows:
        raise ValueError(
            "cannot place equal-width windows {} inside conv output of length {}".format(
                bad_windows, conv_seqlen))

    # output ranges AFTER maxpool
    out_ranges = [(x//maxpool_kernel, y//maxpool_kernel) for x, y in mod_shifted]

    # padded-input slices that produce those conv output windows
    slice_ranges = [(x, y+kernelsize-1) for x, y in mod_shifted]

    # position of each change's start within its slice (padded coordinates)
    offsets = [x+padding-start for (x, _), (start, _) in zip(change_ranges, slice_ranges)]

    return (slice_ranges, offsets), out_ranges

In [123]:
# Position (index along the sequence axis) that is mutated in every example.
MUT_POS = 100

In [124]:
# conv1: kernel 19 with 9 zeros of padding per side, then maxpool 3, over a
# 1000-long input. s_slice is the window of the padded input that is fed
# through conv1+maxpool1 to recompute the outputs affected by MUT_POS;
# mxp1_out_range is the range of maxpool1 outputs that window produces.
([s_slice], _), [mxp1_out_range] = get_idxs_conv_maxpool(1000, 19, 9, 3, [(MUT_POS,MUT_POS+1)])
inp_width = s_slice[1]-s_slice[0]
print(inp_width)

39


In [125]:
# input sequences: random floats, 1000 examples x 1000 positions x 4 channels
inp_seq_np = np.random.randn(1000,1000,4)
inp_seq_np = inp_seq_np.astype(np.float32)
inp_seq_perturbed_np = np.copy(inp_seq_np)
# "mutate" every example at the same position by zeroing all 4 channels
inp_seq_perturbed_np[:, MUT_POS, :] = 0

inp_seq = tf.constant(inp_seq_np)
inp_seq_perturbed = tf.constant(inp_seq_perturbed_np)

In [126]:
# Explicit 9-wide zero padding on each side (what 'same' padding adds for a
# kernel of 19), so the padded-coordinate s_slice can index into it.
padded_inp_seq = tf.concat([tf.zeros((1000,9,4)), 
                            inp_seq_perturbed, 
                            tf.zeros((1000,9,4))], 
                           axis=1)
padded_inp_seq.shape

TensorShape([1000, 1018, 4])

In [127]:
s_slice

(90, 129)

In [128]:
%timeit padded_inp_seq[:, s_slice[0]:s_slice[1], :]

118 µs ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


## Conv1 + Maxpool1

In [178]:
# Fast-path block: 'valid' conv applied to the pre-padded window (width inp_width).
l1 = tf.keras.Sequential()
l1.add(tf.keras.layers.Conv1D(300, 19, strides=1, padding='valid'))
l1.add(tf.keras.layers.BatchNormalization())
l1.add(tf.keras.layers.MaxPool1D(3))

# Reference block: same layers, but 'same' padding over the full 1000-long sequence.
l1_w_padding = tf.keras.models.clone_model(l1)
# NOTE(review): padding is switched before build(); presumably it must be set
# before the layer is built to take effect — confirm on other TF versions.
l1_w_padding.layers[0].padding = 'same'

l1.build(input_shape=(None,s_slice[1]-s_slice[0],4))
l1_w_padding.build(input_shape=(None, 1000, 4))
l1_w_padding.layers[0].set_weights(l1.layers[0].get_weights()) # copy conv weights
# not copying batch norm weights for now
# (both BatchNormalization layers are freshly initialised and untrained here,
# so presumably they behave identically — revisit if trained weights are used)

In [186]:
# Fast path: conv1+maxpool1 on only the window around MUT_POS.
mxp1_out = l1.predict(padded_inp_seq[:, s_slice[0]:s_slice[1], :])
mxp1_out.shape

(1000, 7, 300)

In [187]:
padded_inp_seq.shape

TensorShape([1000, 1018, 4])

In [188]:
inp_seq_perturbed.shape

TensorShape([1000, 1000, 4])

In [189]:
%timeit l1.predict(padded_inp_seq[:, s_slice[0]:s_slice[1], :])

53.6 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [190]:
# Baseline: conv1+maxpool1 over the full perturbed sequences.
mxp1_ism_out = l1_w_padding.predict(inp_seq_perturbed)
mxp1_ism_out.shape

(1000, 333, 300)

In [191]:
%timeit l1_w_padding.predict(inp_seq_perturbed)

915 ms ± 19.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [134]:
mxp1_out_range

(30, 37)

### Compare

In [242]:
# The recomputed window must equal the affected slice of the full baseline output.
# NOTE(review): this relies on exact float equality; np.allclose may be safer
# if the two paths ever use different kernels/devices.
tf.reduce_all(tf.equal(mxp1_out, mxp1_ism_out[:,range(*mxp1_out_range)]))

<tf.Tensor: shape=(), dtype=bool, numpy=True>

---

In [222]:
# Reference (unmutated) maxpool1 output; its unaffected positions are reused
# as the flanks around each recomputed window.
mxp1_out_ref = l1_w_padding.predict(inp_seq)
mxp1_out_ref.shape

(1000, 333, 300)

In [203]:
# conv2: kernel 11, padding 5 per side, then maxpool 4, over the 333-long
# maxpool1 output; the changed range is the maxpool1 window computed above.
([mxp1_out_slice], [mxp1_out_offset]), [mxp2_out_range] = get_idxs_conv_maxpool(333, 11, 5, 4, [mxp1_out_range])

In [207]:
conv2_inp_width = mxp1_out_slice[1]-mxp1_out_slice[0]
print(conv2_inp_width)

30


In [223]:
conv2_inp_num_channels = mxp1_out.shape[2]
conv2_inp_num_channels

300

In [213]:
mxp1_out_offset

11

In [218]:
mxp1_out_slice

(24, 54)

In [243]:
mxp2_out_range

(6, 11)

In [247]:
mxp1_out_ref.shape

(1000, 333, 300)

In [251]:
# Zero-pad the reference by 5 on each side ('same' padding for kernel 11), so
# mxp1_out_slice / mxp1_out_offset (padded coordinates) index into it.
padded_mxp1_out_ref = tf.concat([tf.zeros((1000,5,mxp1_out_ref.shape[2])), 
                            mxp1_out_ref, 
                            tf.zeros((1000,5,mxp1_out_ref.shape[2]))], 
                           axis=1)
padded_mxp1_out_ref.shape

TensorShape([1000, 343, 300])

In [252]:
# conv2 input = reference left flank + recomputed maxpool1 window + reference right flank
conv2_inp = tf.concat([padded_mxp1_out_ref[:, mxp1_out_slice[0]:mxp1_out_slice[0]+mxp1_out_offset],
                      mxp1_out,
                      padded_mxp1_out_ref[:, mxp1_out_slice[0]+mxp1_out_offset+mxp1_out.shape[1]:mxp1_out_slice[1]]],
                     axis=1)
conv2_inp.shape

TensorShape([1000, 30, 300])

## Conv2 + Maxpool2

In [253]:
# Fast-path block: 'valid' conv over the stitched window (width conv2_inp_width).
l2 = tf.keras.Sequential()
l2.add(tf.keras.layers.Conv1D(200, 11, strides=1, padding='valid'))
l2.add(tf.keras.layers.BatchNormalization())
l2.add(tf.keras.layers.MaxPool1D(4))

# Reference block: same layers with 'same' padding over the full maxpool1 output.
l2_w_padding = tf.keras.models.clone_model(l2)
# NOTE(review): set before build(), as for l1_w_padding.
l2_w_padding.layers[0].padding = 'same'

l2.build(input_shape=(None,mxp1_out_slice[1]-mxp1_out_slice[0],conv2_inp_num_channels))
l2_w_padding.build(input_shape=(None, mxp1_ism_out.shape[1], conv2_inp_num_channels))
l2_w_padding.layers[0].set_weights(l2.layers[0].get_weights()) # copy conv weights
# not copying batch norm weights for now

In [254]:
# Fast path for stage 2. NOTE(review): called directly (no training=False);
# presumably BatchNormalization runs in inference mode here — confirm.
mxp2_out = l2(conv2_inp)
mxp2_out.shape

TensorShape([1000, 5, 200])

In [255]:
def f():
    """Time conv2 + maxpool2 on the stitched window (the same input as `conv2_inp`).

    `mxp1_out_slice` / `mxp1_out_offset` index the zero-PADDED maxpool1
    output, so the flanks are taken from `padded_mxp1_out_ref`, exactly as
    when `conv2_inp` was built. Slicing the unpadded `mxp1_out_ref` gives a
    tensor of the same shape, but its contents are shifted by the padding, so
    it would time the wrong input.
    """
    start = mxp1_out_slice[0]
    left = padded_mxp1_out_ref[:, start:start+mxp1_out_offset]
    right = padded_mxp1_out_ref[:, start+mxp1_out_offset+mxp1_out.shape[1]:mxp1_out_slice[1]]
    return l2(tf.concat([left, mxp1_out, right], axis=1))

In [256]:
%timeit f()

94.1 ms ± 1.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [257]:
# Baseline: conv2+maxpool2 over the full maxpool1 output of the perturbed sequences.
mxp2_ism_out = l2_w_padding.predict(mxp1_ism_out)
mxp2_ism_out.shape

(1000, 83, 200)

In [258]:
%timeit l2_w_padding.predict(mxp1_ism_out)

1.44 s ± 16.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Compare

In [287]:
# Recomputed stage-2 window vs. the affected slice of the baseline (exact float equality).
tf.reduce_all(tf.equal(mxp2_out, mxp2_ism_out[:,range(*mxp2_out_range)]))

<tf.Tensor: shape=(), dtype=bool, numpy=True>

---

In [288]:
# Reference (unmutated) maxpool2 output, used for the stage-3 flanks.
mxp2_out_ref = l2_w_padding.predict(mxp1_out_ref)
mxp2_out_ref.shape

(1000, 83, 200)

In [289]:
# conv3: kernel 7, padding 3 per side, then maxpool 4, over the 83-long maxpool2 output.
([mxp2_out_slice], [mxp2_out_offset]), [mxp3_out_range] = get_idxs_conv_maxpool(83, 7, 3, 4, [mxp2_out_range])

In [290]:
conv3_inp_width = mxp2_out_slice[1]-mxp2_out_slice[0]
print(conv3_inp_width)

22


In [291]:
conv3_inp_num_channels = mxp2_out.shape[2]
conv3_inp_num_channels

200

In [292]:
mxp2_out_offset

9

In [293]:
mxp2_out_slice

(0, 22)

In [294]:
mxp3_out_range

(0, 4)

In [302]:
# Zero-pad the reference by 3 on each side ('same' padding for kernel 7), so
# mxp2_out_slice / mxp2_out_offset (padded coordinates) index into it.
padded_mxp2_out_ref = tf.concat([tf.zeros((1000,3,mxp2_out_ref.shape[2])), 
                            mxp2_out_ref, 
                            tf.zeros((1000,3,mxp2_out_ref.shape[2]))], 
                           axis=1)
padded_mxp2_out_ref.shape

TensorShape([1000, 89, 200])

In [303]:
# conv3 input = reference left flank + recomputed maxpool2 window + reference right flank
conv3_inp = tf.concat([padded_mxp2_out_ref[:, mxp2_out_slice[0]:mxp2_out_slice[0]+mxp2_out_offset],
                      mxp2_out,
                      padded_mxp2_out_ref[:, mxp2_out_slice[0]+mxp2_out_offset+mxp2_out.shape[1]:mxp2_out_slice[1]]],
                     axis=1)
conv3_inp.shape

TensorShape([1000, 22, 200])

## Conv3 + Maxpool3

In [304]:
# Fast-path block: 'valid' conv over the stitched window (width conv3_inp_width).
l3 = tf.keras.Sequential()
l3.add(tf.keras.layers.Conv1D(200, 7, strides=1, padding='valid'))
l3.add(tf.keras.layers.BatchNormalization())
l3.add(tf.keras.layers.MaxPool1D(4))

# Reference block: same layers with 'same' padding over the full maxpool2 output.
l3_w_padding = tf.keras.models.clone_model(l3)
# NOTE(review): set before build(), as for l1_w_padding.
l3_w_padding.layers[0].padding = 'same'

l3.build(input_shape=(None,mxp2_out_slice[1]-mxp2_out_slice[0],conv3_inp_num_channels))
l3_w_padding.build(input_shape=(None, mxp2_ism_out.shape[1], conv3_inp_num_channels))
l3_w_padding.layers[0].set_weights(l3.layers[0].get_weights()) # copy conv weights
# not copying batch norm weights for now

In [305]:
# Fast path for stage 3.
mxp3_out = l3(conv3_inp)
mxp3_out.shape

TensorShape([1000, 4, 200])

In [306]:
def f():
    """Time conv3 + maxpool3 on the stitched window (the same input as `conv3_inp`).

    `mxp2_out_slice` / `mxp2_out_offset` index the zero-PADDED maxpool2
    output, so the flanks are taken from `padded_mxp2_out_ref`, exactly as
    when `conv3_inp` was built. The unpadded `mxp2_out_ref` has the same shape
    after slicing, but its contents are shifted by the padding.
    """
    start = mxp2_out_slice[0]
    left = padded_mxp2_out_ref[:, start:start+mxp2_out_offset]
    right = padded_mxp2_out_ref[:, start+mxp2_out_offset+mxp2_out.shape[1]:mxp2_out_slice[1]]
    return l3(tf.concat([left, mxp2_out, right], axis=1))

In [283]:
%timeit f()

39.6 ms ± 547 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [307]:
# Baseline: conv3+maxpool3 over the full maxpool2 output of the perturbed sequences.
mxp3_ism_out = l3_w_padding.predict(mxp2_ism_out)
mxp3_ism_out.shape

(1000, 20, 200)

In [285]:
%timeit l3_w_padding.predict(mxp2_ism_out)

222 ms ± 3.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Compare

In [308]:
# Recomputed stage-3 window vs. the affected slice of the baseline (exact float equality).
tf.reduce_all(tf.equal(mxp3_out, mxp3_ism_out[:,range(*mxp3_out_range)]))

<tf.Tensor: shape=(), dtype=bool, numpy=True>

---

In [309]:
# Reference (unmutated) maxpool3 output, used to complete the FC-head input.
mxp3_out_ref = l3_w_padding.predict(mxp2_out_ref)
mxp3_out_ref.shape

(1000, 20, 200)

In [310]:
# The next layer is fully connected. It can be treated as a conv whose kernel
# spans the whole width (20), with no padding and no maxpool (maxpool width 1).
([mxp3_out_slice], [mxp3_out_offset]), _ = get_idxs_conv_maxpool(20, 20, 0, 1, [mxp3_out_range])

In [312]:
conv4_inp_width = mxp3_out_slice[1]-mxp3_out_slice[0]
print(conv4_inp_width)

20


In [313]:
conv4_inp_num_channels = mxp3_out.shape[2]
conv4_inp_num_channels

200

In [314]:
mxp3_out_offset

0

In [315]:
mxp3_out_slice

(0, 20)

In [317]:
# FC-head input = reference flanks + recomputed maxpool3 window. The
# unpadded reference is correct here because this stage uses padding 0.
conv4_inp = tf.concat([mxp3_out_ref[:, mxp3_out_slice[0]:mxp3_out_slice[0]+mxp3_out_offset],
                      mxp3_out,
                      mxp3_out_ref[:, mxp3_out_slice[0]+mxp3_out_offset+mxp3_out.shape[1]:mxp3_out_slice[1]]],
                     axis=1)
conv4_inp.shape

TensorShape([1000, 20, 200])

### Check with ISM conv3 output

In [318]:
# The stitched tensor must equal the full baseline maxpool3 output.
tf.reduce_all(tf.equal(conv4_inp, mxp3_ism_out))

<tf.Tensor: shape=(), dtype=bool, numpy=True>

## FCs (Fully Connected Layers)

In [319]:
# FC head: two Dense(1000, relu)+BatchNorm blocks followed by a Dense(10) output.
fcs = tf.keras.Sequential([
    tf.keras.layers.Dense(1000, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(1000, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10),
])
# 4000 = 20 positions x 200 channels of the flattened maxpool3 output
fcs.build(input_shape=(None,4000))

In [323]:
# make output -> reshape -> fc
def f():
    """Time stitching the maxpool3 window into the reference, flattening, and running the FC head."""
    lo = mxp3_out_slice[0]
    mid = lo + mxp3_out_offset
    hi = mid + mxp3_out.shape[1]
    stitched = tf.concat([mxp3_out_ref[:, lo:mid],
                          mxp3_out,
                          mxp3_out_ref[:, hi:mxp3_out_slice[1]]],
                         axis=1)
    flat = tf.reshape(stitched, (-1, 4000))
    fcs.predict(flat)

In [325]:
%timeit f()

122 ms ± 2.34 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [326]:
# only reshape -> fc
def f():
    """Time the FC head alone on the already complete baseline maxpool3 output."""
    flat = tf.reshape(mxp3_ism_out, (-1, 4000))
    fcs.predict(flat)

In [327]:
%timeit f()

109 ms ± 1.48 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Standoff

In [702]:
# without optimisations
def normalISMModel():
    """Build the unoptimised Basset-style model used as the ISM baseline.

    Every conv uses 'same' padding, so each forward pass processes the full
    1000-long sequence. Conv and FC layers are named conv1..conv3 / fc1..fc3,
    so their weights can be copied in by name.

    Returns:
        tf.keras.Model named 'normalISM' mapping (batch, 1000, 4) inputs to
        (batch, 10) outputs.
    """
    seq_input = tf.keras.Input(shape=(1000,4))
    h = seq_input

    # (filters, kernel width, maxpool width, layer name) for each conv stage
    conv_specs = (
        (300, 19, 3, 'conv1'),
        (200, 11, 4, 'conv2'),
        (200, 7, 4, 'conv3'),
    )
    for n_filters, width, pool, layer_name in conv_specs:
        h = tf.keras.layers.Conv1D(n_filters, width, strides=1, padding='same', name=layer_name)(h)
        h = tf.keras.layers.BatchNormalization()(h)
        h = tf.keras.layers.MaxPool1D(pool)(h)

    # flatten: 1000 -> 333 -> 83 -> 20 positions, x 200 channels = 4000
    h = tf.keras.layers.Reshape((4000,))(h)
    for layer_name in ('fc1', 'fc2'):
        h = tf.keras.layers.Dense(1000, activation='relu', name=layer_name)(h)
        h = tf.keras.layers.BatchNormalization()(h)
    h = tf.keras.layers.Dense(10, name='fc3')(h)

    return tf.keras.Model(inputs=seq_input, outputs=h, name='normalISM')

In [703]:
normal_ISM_model = normalISMModel()

In [704]:
type(inp_seq_perturbed)

tensorflow.python.framework.ops.EagerTensor

In [705]:
def run_normal():
    """Run the unoptimised model on all perturbed full-length sequences in one batch."""
    model = normal_ISM_model
    return model.predict_on_batch(inp_seq_perturbed)

In [527]:
%timeit run_normal()

1.97 s ± 38.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [708]:
class SliceAssign(tf.keras.layers.Layer):
    """Functional slice assignment: return `a` with `a[:, i:i+b_dim]` replaced by `b`.

    Keras tensors do not support item assignment, so the result is built by
    concatenating the left part of `a`, then `b`, then the right part of `a`
    along axis 1.

    Args:
        b_dim: length of `b` along axis 1. It is passed explicitly because,
            after one SliceAssign, TF can no longer infer the static length
            of axis 1 (the offset is only known at run time).
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, b_dim, **kwargs):
        super().__init__(**kwargs)
        self.b_dim = b_dim

    def call(self, inputs):
        # GOAL: a[:,i:i+b.shape[1]] = b
        a, b, i = inputs
        # the start offset is the first element of `i` (a length-1 int tensor
        # when fed from a batch_size=1 Input, as in fastISMModel)
        start = i[0]
        # output will lose shape info (dim 1 will be set to None)
        return tf.concat([a[:, :start], b, a[:, start + self.b_dim:]], axis=1)

    def get_config(self):
        # Include the constructor argument so Keras can serialize/clone the layer.
        config = super().get_config()
        config.update({'b_dim': self.b_dim})
        return config
    
    
def fastISMModel():
    """Build the optimised ISM model that only recomputes windows around the mutation.

    Every conv uses 'valid' padding on a narrow window. After each
    conv + maxpool stage, SliceAssign writes the recomputed window into a
    window of precomputed reference activations, and the result feeds the next
    stage. The FC head then runs on the full spliced maxpool3 output.

    Model inputs, in order:
        inp: window of the zero-padded input sequence, shape (batch, inp_width, 4).
        padded_mxp1_out_ref, mxp1_out_offset: zero-padded reference maxpool1
            window, and the int32 position at which the recomputed window is
            inserted into it.
        padded_mxp2_out_ref, mxp2_out_offset: the same for maxpool2.
        mxp3_out_ref, mxp3_out_offset: the same for maxpool3 (not padded,
            because the FC head is treated as a conv with no padding).

    NOTE(review): reads notebook globals (inp_width, conv{2,3,4}_inp_width,
    conv{2,3,4}_inp_num_channels, mxp{1,2,3}_out_range) when called; the
    local Input names shadow globals with the same names.

    Returns:
        tf.keras.Model named 'fastISM'.
    """
    inp = tf.keras.Input(shape=(inp_width,4))
    padded_mxp1_out_ref = tf.keras.Input(shape=(conv2_inp_width, conv2_inp_num_channels))
    # offset inputs have batch_size=1: a single offset (read as i[0] by
    # SliceAssign) is shared by every example in the batch
    mxp1_out_offset = tf.keras.Input(batch_size=1, shape=(), dtype='int32')
    
    padded_mxp2_out_ref = tf.keras.Input(shape=(conv3_inp_width, conv3_inp_num_channels))
    mxp2_out_offset = tf.keras.Input(batch_size=1, shape=(), dtype='int32')
    
    mxp3_out_ref = tf.keras.Input(shape=(conv4_inp_width, conv4_inp_num_channels))
    mxp3_out_offset = tf.keras.Input(batch_size=1, shape=(), dtype='int32')
    
    # conv mxp 1
    x = tf.keras.layers.Conv1D(300, 19, strides=1, padding='valid', name='conv1')(inp)
    x = tf.keras.layers.BatchNormalization()(x) 
    x = tf.keras.layers.MaxPool1D(3)(x)
    
    # slice assign
    # (b_dim = width of the recomputed maxpool1 window)
    x = SliceAssign(mxp1_out_range[1]-mxp1_out_range[0])([padded_mxp1_out_ref, x, mxp1_out_offset])

    # conv mxp 2
    x = tf.keras.layers.Conv1D(200, 11, strides=1, padding='valid', name='conv2')(x)
    x = tf.keras.layers.BatchNormalization()(x) 
    x = tf.keras.layers.MaxPool1D(4)(x)
    
    # slice assign
    x = SliceAssign(mxp2_out_range[1]-mxp2_out_range[0])([padded_mxp2_out_ref, x, mxp2_out_offset])
    
    # conv mxp 3
    x = tf.keras.layers.Conv1D(200, 7, strides=1, padding='valid', name='conv3')(x)
    x = tf.keras.layers.BatchNormalization()(x) 
    x = tf.keras.layers.MaxPool1D(4)(x)
    
    # slice assign
    x = SliceAssign(mxp3_out_range[1]-mxp3_out_range[0])([mxp3_out_ref, x, mxp3_out_offset])
    
    # fc
    # 4000 = conv4_inp_width (20) x conv4_inp_num_channels (200)
    x = tf.keras.layers.Reshape((4000,))(x)
    x = tf.keras.layers.Dense(1000, activation='relu', name='fc1')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dense(1000, activation='relu', name='fc2')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dense(10, name='fc3')(x)
    
    # input order here must match the order of tensors passed to predict
    model = tf.keras.Model(inputs=[inp, 
                                   padded_mxp1_out_ref, mxp1_out_offset,
                                   padded_mxp2_out_ref, mxp2_out_offset,
                                   mxp3_out_ref, mxp3_out_offset], 
                           outputs=x, name='fastISM')
    
    return model

In [709]:
fast_ISM_model = fastISMModel()

In [710]:
# inputs 
type(padded_mxp1_out_ref)

tensorflow.python.framework.ops.EagerTensor

In [711]:
padded_inp_seq[:, s_slice[0]:s_slice[1], :].shape

TensorShape([1000, 39, 4])

In [712]:
padded_mxp1_out_ref.shape

TensorShape([1000, 343, 300])

In [713]:
def run_fast():
    """Run the optimised ISM model on the mutation window plus the cached reference windows.

    Inputs are passed in the order declared by `fastISMModel`. The offset
    inputs are int32 `Input`s with batch_size=1, so they are passed as
    explicit int32 tensors of shape (1,). The original float32
    `tf.ones(1) * offset` relied on Keras converting floats into an integer
    slice index.
    """
    return fast_ISM_model.predict_on_batch([
        padded_inp_seq[:, s_slice[0]:s_slice[1], :],
        padded_mxp1_out_ref[:, mxp1_out_slice[0]:mxp1_out_slice[1]],
        tf.constant([mxp1_out_offset], dtype=tf.int32),
        padded_mxp2_out_ref[:, mxp2_out_slice[0]:mxp2_out_slice[1]],
        tf.constant([mxp2_out_offset], dtype=tf.int32),
        mxp3_out_ref,
        tf.constant([mxp3_out_offset], dtype=tf.int32),
    ])

In [574]:
%timeit run_fast()

207 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [714]:
# Run each model once so all of its layers are built before weights are assigned.
for build_runner in (run_normal, run_fast):
    build_runner()

# Copy the weights of the standalone layers into both models, so that cached
# reference tensors such as `padded_mxp1_out_ref` (computed with those layers)
# stay valid. In `fcs`, indices 0, 2 and 4 are the Dense layers; 1 and 3 are
# BatchNormalization.
copy_plan = (
    ("conv1", l1.layers[0]),
    ("conv2", l2.layers[0]),
    ("conv3", l3.layers[0]),
    ("fc1", fcs.layers[0]),
    ("fc2", fcs.layers[2]),
    ("fc3", fcs.layers[4]),
)
for target_name, source_layer in copy_plan:
    for target_model in (fast_ISM_model, normal_ISM_model):
        target_model.get_layer(target_name).set_weights(source_layer.get_weights())

In [715]:
tf.reduce_all(tf.equal(run_normal(), run_fast()))

<tf.Tensor: shape=(), dtype=bool, numpy=True>