In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Length of the full input sequence, and length of each chunk fed to the
# re-shaped model in the chunked computation below.
input_dims = 20
modified_shape = 10
input_shape = [input_dims, 1]  # per-sample shape: (length, channels)

# Seed both RNGs (NumPy for the input data, TensorFlow for the kernel
# initializers) so Restart & Run All reproduces the printed numbers.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [42]:
def get_model(input_shape, is_kernel_one):
    """Build the small 1-D test network: pad -> Conv1D -> pad -> Conv1D.

    Args:
        input_shape: shape of one sample, forwarded to ``tf.keras.Input``
            (in this notebook ``[length, channels]``).
        is_kernel_one: if True, both conv kernels are initialised to ones;
            otherwise the ``'random_uniform'`` initializer is used. Biases are
            always initialised to zeros.

    Returns:
        A ``tf.keras.Model`` named ``"test_model"``.
    """
    init = 'ones' if is_kernel_one else 'random_uniform'

    def conv(strides):
        # The two convolutions differ only in their stride.
        return tf.keras.layers.Conv1D(
            filters=1,
            kernel_size=3,
            strides=strides,
            padding='valid',
            kernel_initializer=init,
            bias_initializer='zeros',
        )

    inputs = tf.keras.Input(shape=input_shape)
    # All paddings are 0 here, so this pad adds no elements.
    hidden = tf.pad(inputs, paddings=[[0, 0], [0, 0], [0, 0]], mode='CONSTANT', constant_values=1)
    hidden = conv(strides=1)(hidden)
    # Prepend 4 constant ones on axis 1 (the length axis) before the strided conv.
    hidden = tf.pad(hidden, paddings=[[0, 0], [4, 0], [0, 0]], mode='CONSTANT', constant_values=1)
    outputs = conv(strides=2)(hidden)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name="test_model")

In [43]:
def get_input_tensor(input_shape, is_kernel_one):
    """Create a single-sample batch of shape ``(1, *input_shape)`` (float64).

    Args:
        input_shape: per-sample shape, e.g. ``[length, channels]``.
        is_kernel_one: if True, return all zeros (used with all-ones kernels to
            probe which outputs are influenced by padding); otherwise return
            uniform random values in [0, 1) from NumPy's global RNG.

    Returns:
        ``np.ndarray`` of shape ``(1, *input_shape)``.
    """
    if is_kernel_one:
        # Build the zero tensor directly: drawing random numbers only to
        # discard them wastes work and needlessly advances the global RNG.
        return np.zeros(input_shape)[np.newaxis, ...]
    return np.random.rand(*input_shape)[np.newaxis, ...]

def get_output(model, input_tensor):
    """Call ``model`` on ``input_tensor`` and return the result as a 1-D NumPy array.

    ``model`` may be any callable (e.g. a ``tf.keras.Model``) whose result can
    be converted with ``np.array``; the result is flattened in C order.
    """
    return np.array(model(input_tensor)).reshape(-1)

### Receptive field of the network

In [44]:
def get_receptive_field(model):
    """Compute the receptive field of a linear chain of 1-D layers.

    Walks ``model.layers`` from the output back to the input applying
    ``R = R * stride + (extent - stride)``. Here ``extent`` is the span of one
    window in samples, ``dilation * (window - 1) + 1``, and ``window`` is the
    layer's ``kernel_size`` (convolutions) or ``pool_size`` (pooling). Layers
    with neither (input, padding ops, ...) count as window 1 and stride 1.
    Only the first element of ``strides`` / ``kernel_size`` / ``pool_size`` /
    ``dilation_rate`` is read, so the layers are treated as 1-D. Assumes
    ``model.layers`` lists a single linear chain in input-to-output order.

    Args:
        model: object with a ``layers`` list whose items expose
            ``get_config()`` (e.g. a ``tf.keras.Model``).

    Returns:
        ``(receptive_field, layer_wise)``. ``receptive_field`` is the number of
        input samples that influence one output element. ``layer_wise`` maps
        each layer name to the receptive field measured at that layer's input.
    """
    layer_wise = {}
    receptive_field = 1
    for layer in model.layers[::-1]:
        config = layer.get_config()
        stride = config["strides"][0] if "strides" in config else 1
        if "kernel_size" in config:
            window = config["kernel_size"][0]
        elif "pool_size" in config:
            # Pooling windows grow the receptive field like a kernel does.
            window = config["pool_size"][0]
        else:
            window = 1
        dilation = config.get("dilation_rate", (1,))[0]
        # A dilated window of size k spans dilation * (k - 1) + 1 samples.
        extent = dilation * (window - 1) + 1
        receptive_field = receptive_field * stride + (extent - stride)
        layer_wise[config["name"]] = receptive_field
    return receptive_field, layer_wise

In [45]:
# Build the network once to analyse its geometry (receptive field, stride,
# padding). No input tensor is needed here; the reference cell creates one.
is_kernel_one = False
model = get_model(input_shape, is_kernel_one=is_kernel_one)
receptive_field = get_receptive_field(model)[0]
print("Receptive field", receptive_field)

Receptive field 5


### Effective stride

In [46]:
def get_effective_stride(model):
    """Compute the effective (cumulative) stride of a linear chain of layers.

    Multiplies the first element of every layer's ``strides`` config entry.
    Layers without ``strides`` (input, padding ops, ...) count as stride 1.
    Assumes ``model.layers`` lists a single linear chain in input-to-output
    order.

    Args:
        model: object with a ``layers`` list whose items expose
            ``get_config()`` (e.g. a ``tf.keras.Model``).

    Returns:
        ``(effective_stride, layer_wise)``. ``layer_wise`` maps each layer name
        to the product of the strides of that layer and every layer after it.
    """
    layer_wise = {}
    effective_stride = 1
    for layer in model.layers[::-1]:
        config = layer.get_config()
        stride = config["strides"][0] if "strides" in config else 1
        effective_stride *= stride
        layer_wise[config["name"]] = effective_stride
    return effective_stride, layer_wise

In [47]:
# Keep only the overall stride; the per-layer breakdown is not needed here.
effective_stride = get_effective_stride(model=model)[0]
print("Effective stride", effective_stride)

Effective stride 2


### Effective padding

In [48]:
def get_effective_padding(model):
    """Compute the effective total padding of a linear chain of 1-D layers.

    Walks ``model.layers`` from the output back to the input and applies
    ``P = stride * P + p`` at every layer. ``p`` is the total (before + after)
    padding that the layer adds along the length axis (axis 1). Assumes
    ``model.layers`` lists a single linear chain in input-to-output order.

    ``p`` is taken from:
      * layers whose config ``padding`` is ``'same'`` or ``'causal'``
        (convolution / pooling): ``dilation * (window - 1)`` in total.
        ``'same'`` with stride > 1 depends on the input length and is rejected.
      * TensorFlow op layers whose ``node_def`` name contains ``"Pad"``: the
        before/after amounts for axis 1 only. Padding of other axes (batch,
        channels) does not change the length and is ignored.
      * everything else (including ``'valid'``): 0.

    Args:
        model: object with a ``layers`` list whose items expose
            ``get_config()`` (e.g. a ``tf.keras.Model``).

    Returns:
        ``(effective_padding, layer_wise)``. ``layer_wise`` maps each layer
        name to the effective padding accumulated from that layer to the output.

    Raises:
        NotImplementedError: for ``padding='same'`` combined with stride > 1.
    """
    layer_wise = {}
    effective_padding = 0
    for layer in model.layers[::-1]:
        config = layer.get_config()
        stride = config["strides"][0] if "strides" in config else 1
        padding = 0

        mode = config.get("padding")
        if mode in ("same", "causal"):
            if "kernel_size" in config:
                window = config["kernel_size"][0]
            else:
                window = config.get("pool_size", (1,))[0]
            dilation = config.get("dilation_rate", (1,))[0]
            if mode == "same" and stride != 1:
                raise NotImplementedError(
                    f"layer {config['name']!r}: padding='same' with stride {stride} "
                    "depends on the input length"
                )
            padding = dilation * (window - 1)

        node_def = config.get("node_def")
        if node_def is not None and "Pad" in node_def["name"]:
            # constants[1] is the paddings argument given to tf.pad: one
            # [before, after] pair per axis. Only axis 1 (length) counts.
            paddings = np.asarray(config["constants"][1])
            padding = int(np.sum(paddings[1]))

        effective_padding = stride * effective_padding + padding
        layer_wise[config["name"]] = effective_padding
    return effective_padding, layer_wise

In [49]:
# Keep only the overall padding; the per-layer breakdown is not needed here.
effective_padding = get_effective_padding(model=model)[0]
print("Effective padding", effective_padding)

Effective padding 4


### Reference output

In [50]:
# Reference: the full-length input processed in a single forward pass.
input_shape = [input_dims, 1]
# Set to True to use all-ones kernels and an all-zero input instead.
is_kernel_one = False
input_tensor = get_input_tensor(input_shape, is_kernel_one)
model = get_model(input_shape, is_kernel_one=is_kernel_one)
model.summary()
ref_output = get_output(model, input_tensor)
print(ref_output)

Model: "test_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 20, 1)]           0         
_________________________________________________________________
tf_op_layer_PadV2_8 (TensorF [(None, 20, 1)]           0         
_________________________________________________________________
conv1d_8 (Conv1D)            (None, 18, 1)             4         
_________________________________________________________________
tf_op_layer_PadV2_9 (TensorF [(None, 22, 1)]           0         
_________________________________________________________________
conv1d_9 (Conv1D)            (None, 10, 1)             4         
Total params: 8
Trainable params: 8
Non-trainable params: 0
_________________________________________________________________
[-0.08486608 -0.08086114  0.00321689  0.00395104  0.00220385  0.00275447
  0.00191425  0.00336804  0.0033351   0.00103537]


### Influence of padding on the output

In [51]:
# Probe: all-ones kernels, zero biases and an all-zero input, so a non-zero
# output can only come from the constant (value 1) padding.
# NOTE(review): count_nonzero counts every influenced output; the chunking
# loop assumes they form a prefix (left padding only) — TODO confirm if the
# padding layout changes.
model_ref = get_model(input_shape, is_kernel_one=True)
input_tensor_zeros = get_input_tensor(input_shape, is_kernel_one=True)
pad_output = get_output(model=model_ref, input_tensor=input_tensor_zeros)
num_output_influenced = np.count_nonzero(pad_output)
print(pad_output)
print("Padding influenced", num_output_influenced)

[3. 2. 0. 0. 0. 0. 0. 0. 0. 0.]
Padding influenced 2


### How to find the overlap needed?
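Consecutive input chunks must overlap so that the outputs contaminated by padding can be discarded. The overlap, in input samples, is the number of outputs influenced by padding multiplied by the effective stride. Each discarded output corresponds to `effective_stride` input samples, so the next chunk must start that many samples earlier.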

In [64]:
# Input samples shared by consecutive chunks: each padding-influenced output
# dropped from a chunk corresponds to `effective_stride` input samples.
# NOTE(review): originally annotated "this is wrong". The chunking loop uses the
# same shift, (idx - num_output_influenced) * effective_stride, and reproduces
# `ref_output` for this model — TODO confirm for other architectures.
overlap = effective_stride * num_output_influenced
print("Overlap", overlap)

Overlap 4


### Split the input and compute the outputs chunk by chunk

In [53]:
def modify_input(model, batch_input_shape):
    """Rebuild ``model`` with a new input shape and copy its weights into it.

    The config's first layer gets the new ``batch_input_shape``. This assumes
    that layer is the InputLayer, as in the summaries above. A new model is
    then created from the config, the original weights are copied in, and the
    new model's summary is printed.

    Args:
        model: a functional ``tf.keras.Model``.
        batch_input_shape: new input shape including the batch dimension,
            e.g. ``(None, length, channels)``.

    Returns:
        The new ``tf.keras.Model``.
    """
    model_config = model.get_config()
    model_config["layers"][0]["config"]["batch_input_shape"] = batch_input_shape
    modified_model = tf.keras.Model.from_config(model_config)
    # get_weights() yields NumPy arrays in the order set_weights() expects;
    # both models share the same architecture, so the order matches.
    modified_model.set_weights(model.get_weights())
    modified_model.summary()
    return modified_model

In [54]:
batch_input_shape = (None, modified_shape, 1)  # (batch, chunk length, channels)
model.summary()
# Same architecture and weights, but accepting chunks of `modified_shape` samples.
modified_model = modify_input(model=model, batch_input_shape=batch_input_shape)

Model: "test_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 20, 1)]           0         
_________________________________________________________________
tf_op_layer_PadV2_8 (TensorF [(None, 20, 1)]           0         
_________________________________________________________________
conv1d_8 (Conv1D)            (None, 18, 1)             4         
_________________________________________________________________
tf_op_layer_PadV2_9 (TensorF [(None, 22, 1)]           0         
_________________________________________________________________
conv1d_9 (Conv1D)            (None, 10, 1)             4         
Total params: 8
Trainable params: 8
Non-trainable params: 0
_________________________________________________________________
Model: "test_model"
_________________________________________________________________
Layer (type)                 Output Shape 

In [55]:
np.shape(input_tensor)

(1, 20, 1)

In [56]:
(ref_output, np.shape(ref_output))

(array([-0.08486608, -0.08086114,  0.00321689,  0.00395104,  0.00220385,
         0.00275447,  0.00191425,  0.00336804,  0.0033351 ,  0.00103537],
       dtype=float32),
 (10,))

In [63]:
# Rebuild the full output from chunks of at most `modified_shape` input
# samples. After the first chunk, each chunk starts
# `num_output_influenced * effective_stride` samples before the next output
# still needed. Its first `num_output_influenced` outputs, which are
# contaminated by the left padding, are discarded.
output = []
# Standard output-length formula: (L + P - R) // S + 1.
output_dims = (input_dims + effective_padding - receptive_field) // effective_stride + 1
while len(output) < output_dims:
    idx = len(output)  # index of the next output still missing
    start_idx = max(0, (idx - num_output_influenced) * effective_stride)

    sliced_input = input_tensor[:, start_idx:start_idx + modified_shape, :]
    partial_output = np.array(modified_model(sliced_input)).flatten()
    if idx != 0:
        partial_output = partial_output[num_output_influenced:]
    if partial_output.size == 0:
        # Without this guard the loop would never make progress.
        raise RuntimeError(
            f"chunk starting at input index {start_idx} produced no new outputs; "
            "increase modified_shape"
        )
    output.extend(partial_output)
    print(idx, start_idx, sliced_input.shape, partial_output)
print(output)
print(np.isclose(ref_output, output))
# Fail loudly rather than relying on a reader to scan the boolean array above.
np.testing.assert_allclose(output, ref_output, rtol=1e-5, atol=1e-8)

0 0 (1, 10, 1) [-0.08486608 -0.08086114  0.00321689  0.00395104  0.00220385]
5 6 (1, 10, 1) [0.00275447 0.00191425 0.00336804]
8 12 (1, 8, 1) [0.0033351  0.00103537]
[-0.08486608, -0.080861144, 0.0032168918, 0.0039510448, 0.002203849, 0.0027544708, 0.0019142547, 0.0033680415, 0.0033350994, 0.0010353675]
[ True  True  True  True  True  True  True  True  True  True]
