In [282]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import *
import tensorflow_datasets as tfds

In [283]:
class ZeroMask():
    """
    Callable that zeroes a fixed set of positions in an array.

    On construction a mask of ones with the given shape is built and every position listed
    in ``zero_indices`` is set to 0. Calling the instance multiplies its argument
    element-wise by that mask, so the returned array is 0 at those positions and unchanged
    elsewhere.

    Having this as a callable class lets the same mask be reused many times (e.g. on a
    kernel when it is built and on that kernel's gradients) in a cleanly abstracted way.
    """
    def __init__(self, shape, zero_indices):
        """
        inputs:
            shape: shape of the array to be masked. For a Keras Conv2D kernel this is
                (kernel_rows, kernel_cols, in_channels, filters), e.g. (3, 3, 1, 40).
            zero_indices: iterable of indices to be masked to 0, one full index per entry,
                e.g. [(row, col, channel, filter)]. If zero_indices is [(r, c, ch, f)] then
                mask[r, c, ch, f] is set to 0. Multi-dimensional indices may be given as
                tuples or lists; for a 1-D mask plain ints also work.
        instance variables:
            mask: float numpy array of ``shape``, 1. everywhere except 0. at zero_indices.
        """
        self.mask = np.ones(shape)
        for index in zero_indices:
            # numpy treats a list (or ndarray) index as fancy indexing along the first axis,
            # which would zero whole slices instead of one element; a tuple addresses exactly
            # one position.
            if isinstance(index, (list, np.ndarray)):
                index = tuple(index)
            self.mask[index] = 0.

    def __call__(self, arr):
        """
        Return ``self.mask * arr``.

        ``arr`` must have exactly the mask's shape (checked with an assert); it can be
        anything that supports ``.shape`` and multiplication with a numpy array, such as a
        numpy array or a TensorFlow tensor.
        """
        assert arr.shape == self.mask.shape, "Incorrect dims, mask shape: {}, arr shape: {}".format(self.mask.shape, arr.shape)
        return self.mask * arr


In [284]:
def pattern_maker(filters, patterns, pattern_freq, kernel_size = (3,3), in_channels = 1):
    """
    Build the ``zero_indices`` list for a pattern-pruned convolution kernel.

    Filters are assigned to patterns in order: the first ``pattern_freq[0]`` filters use
    ``patterns[0]``, the next ``pattern_freq[1]`` filters use ``patterns[1]``, and so on.
    Every kernel position where a pattern holds 0 is emitted as a zero index for each of
    that pattern's filters (and for each input channel).

    inputs:
        filters: total number of filters; must equal sum(pattern_freq).
        patterns: sequence of kernel_size[0] x kernel_size[1] nested lists; an entry of 0
            marks a kernel position to prune.
        pattern_freq: number of consecutive filters that use each pattern; must have the
            same length as ``patterns``.
        kernel_size: (rows, cols) of the kernel.
        in_channels: number of input channels; every pruned position is zeroed in each
            channel. The default of 1 matches a single-channel input.
    returns:
        list of (row, col, channel, filter) tuples, grouped by pattern, then by kernel
        position, then by filter.
    """
    # Materialize once so generators are not exhausted by the checks below.
    patterns = list(patterns)
    pattern_freq = list(pattern_freq)
    assert filters == sum(pattern_freq)
    # zip() would silently drop unmatched patterns/frequencies.
    assert len(patterns) == len(pattern_freq), "each pattern needs exactly one frequency"

    rows, cols = kernel_size[0], kernel_size[1]
    zero_indices = []
    first_filter = 0
    for pattern, freq in zip(patterns, pattern_freq):
        for i in range(rows):
            for j in range(cols):
                if pattern[i][j] == 0.:
                    for filter_idx in range(first_filter, first_filter + freq):
                        for channel in range(in_channels):
                            zero_indices.append((i, j, channel, filter_idx))
        first_filter += freq

    return zero_indices

# 3x3 kernel patterns: 1 keeps a kernel weight, 0 prunes it.
patterns = [[[0, 1, 0], [0, 1, 1], [0, 1, 0]],
            [[0, 1, 0], [1, 1, 1], [0, 0, 0]],
            [[0, 0, 0], [1, 1, 1], [0, 1, 0]],
            [[0, 1, 0], [1, 1, 0], [0, 1, 0]],
            [[1, 0, 0], [1, 1, 0], [1, 0, 0]],
            [[1, 1, 0], [0, 1, 0], [0, 1, 0]],
            [[0, 0, 1], [0, 1, 1], [0, 0, 1]],
            [[1, 1, 1], [0, 1, 0], [0, 0, 0]],
            [[0, 0, 0], [0, 1, 0], [1, 1, 1]],
            [[0, 0, 1], [0, 1, 0], [1, 0, 1]]]

# Number of consecutive filters that share each pattern.
FILTERS_PER_PATTERN = 4
pattern_freq = [FILTERS_PER_PATTERN] * len(patterns)

# The filter count is derived from the frequencies so it stays consistent if patterns change.
zero_indices = pattern_maker(sum(pattern_freq), patterns, pattern_freq)

print(zero_indices)

[(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 0, 2), (0, 0, 0, 3), (0, 2, 0, 0), (0, 2, 0, 1), (0, 2, 0, 2), (0, 2, 0, 3), (1, 0, 0, 0), (1, 0, 0, 1), (1, 0, 0, 2), (1, 0, 0, 3), (2, 0, 0, 0), (2, 0, 0, 1), (2, 0, 0, 2), (2, 0, 0, 3), (2, 2, 0, 0), (2, 2, 0, 1), (2, 2, 0, 2), (2, 2, 0, 3), (0, 0, 0, 4), (0, 0, 0, 5), (0, 0, 0, 6), (0, 0, 0, 7), (0, 2, 0, 4), (0, 2, 0, 5), (0, 2, 0, 6), (0, 2, 0, 7), (2, 0, 0, 4), (2, 0, 0, 5), (2, 0, 0, 6), (2, 0, 0, 7), (2, 1, 0, 4), (2, 1, 0, 5), (2, 1, 0, 6), (2, 1, 0, 7), (2, 2, 0, 4), (2, 2, 0, 5), (2, 2, 0, 6), (2, 2, 0, 7), (0, 0, 0, 8), (0, 0, 0, 9), (0, 0, 0, 10), (0, 0, 0, 11), (0, 1, 0, 8), (0, 1, 0, 9), (0, 1, 0, 10), (0, 1, 0, 11), (0, 2, 0, 8), (0, 2, 0, 9), (0, 2, 0, 10), (0, 2, 0, 11), (2, 0, 0, 8), (2, 0, 0, 9), (2, 0, 0, 10), (2, 0, 0, 11), (2, 2, 0, 8), (2, 2, 0, 9), (2, 2, 0, 10), (2, 2, 0, 11), (0, 0, 0, 12), (0, 0, 0, 13), (0, 0, 0, 14), (0, 0, 0, 15), (0, 2, 0, 12), (0, 2, 0, 13), (0, 2, 0, 14), (0, 2, 0, 15), (1, 2, 0, 12), (1, 2, 0, 13),

In [285]:
# Kernel mask laid out as (rows, cols, in_channels, filters), matching the
# (i, j, channel, filter) tuples produced by pattern_maker; the filter count is taken
# from pattern_freq instead of being hard-coded.
mask_func = ZeroMask(shape = (3, 3, 1, sum(pattern_freq)), zero_indices = zero_indices)

In [286]:
# Load the MNIST train and test splits as (image, label) pairs (as_supervised=True),
# together with the dataset metadata (with_info=True).
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    as_supervised=True,
    shuffle_files=True,
    with_info=True,
)
def normalize_img(image, label):
    """Scale a `uint8` image to `float32` values in [0, 1]; the label passes through unchanged."""
    scaled_image = tf.cast(image, tf.float32) / 255.
    return scaled_image, label

# Training pipeline: normalize, cache, reshuffle over the full split, batch, prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

# Test pipeline: normalize, batch, then cache the batches (no shuffling), prefetch.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

In [287]:
# custom conv layer with a fixed zero pattern in its kernel
class IregConv2D(tf.keras.Model):
    """
    Conv2D wrapper whose kernel has a fixed set of entries held at zero.

    The mask is applied to the kernel once when the layer is built. It is also installed as
    the wrapped Conv2D's ``kernel_constraint``. Keras optimizers apply a variable's
    constraint after each gradient update, so the zeros are re-imposed independently of
    how the training loop treats the gradients.
    """
    def __init__(self, zero_mask, filters, kernel_size, *args, **kwargs):
        """
        inputs:
            zero_mask: callable that takes the kernel and returns the masked kernel (e.g. a
                ZeroMask whose shape equals the Conv2D kernel shape).
            filters, kernel_size, *args, **kwargs: forwarded to Conv2D. If a
                ``kernel_constraint`` is supplied it is kept and applied before the mask.
        """
        super(IregConv2D, self).__init__()
        user_constraint = kwargs.pop('kernel_constraint', None)
        if user_constraint is None:
            kernel_constraint = zero_mask
        else:
            user_constraint = tf.keras.constraints.get(user_constraint)

            def kernel_constraint(w):
                # Apply the caller's constraint first, then re-impose the zero pattern.
                return zero_mask(user_constraint(w))

        # filters/kernel_size are passed positionally so any extra positional *args
        # (strides, padding, ...) line up with Conv2D's signature instead of colliding
        # with the `filters` keyword.
        self.conv = Conv2D(filters, kernel_size, *args, kernel_constraint=kernel_constraint, **kwargs)
        self.zero_mask = zero_mask

    def call(self, x):
        return self.conv(x)

    def build(self, input_shape):
        """Build the wrapped Conv2D, then zero the masked kernel entries in its initial weights."""
        super(IregConv2D, self).build(input_shape)
        weights = self.conv.get_weights()
        weights[0] = self.zero_mask(weights[0])  # weights[0] is the kernel; weights[1] the bias
        self.conv.set_weights(weights)

    def get_conv(self):
        """Return the wrapped Conv2D layer."""
        return self.conv
    
    
class TestNet(tf.keras.Model):
    """MNIST classifier: pattern-masked 3x3 conv (relu) -> flatten -> dense(32, relu) -> dense(10, softmax)."""
    def __init__(self, zero_mask=None, filters=40):
        """
        inputs:
            zero_mask: callable mask for the conv kernel (e.g. a ZeroMask). Defaults to the
                notebook-level ``mask_func`` so existing ``TestNet()`` calls keep working.
            filters: number of conv filters; must be consistent with the mask's shape.
        """
        super(TestNet, self).__init__()
        if zero_mask is None:
            zero_mask = mask_func
        self.conv = IregConv2D(zero_mask, filters = filters, kernel_size = (3, 3), activation = 'relu')
        self.flatten = Flatten()
        self.dense = Dense(32, activation = 'relu')
        # Softmax head: the model outputs class probabilities, not logits.
        self.head = Dense(10, activation = 'softmax')

    def call(self, x):
        x = self.conv(x)
        x = self.flatten(x)
        x = self.dense(x)
        x = self.head(x)
        return x

In [288]:
# Instantiate the pattern-masked MNIST classifier defined above.
model = TestNet()

In [289]:
def grad(model, inputs, targets):
    """
    Compute the training-mode loss for one batch and its gradients.

    returns:
        (loss_value, gradients), with gradients in the same order as
        model.trainable_variables.
    """
    with tf.GradientTape() as tape:
        batch_loss = loss(model, inputs, targets, training=True)
    gradients = tape.gradient(batch_loss, model.trainable_variables)
    return batch_loss, gradients

In [290]:
# TestNet's head is Dense(10, activation='softmax'), so the model outputs probabilities,
# not logits; from_logits must be False to match.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
def loss(model, x, y, training):
    """
    Run ``model`` on ``x`` and score the predictions against labels ``y`` with ``loss_object``.

    ``training`` is forwarded to the model call; it only matters for layers whose behavior
    differs between training and inference (e.g. Dropout).
    """
    predictions = model(x, training=training)
    return loss_object(y_true=y, y_pred=predictions)
# Adam with a fixed learning rate of 1e-2.
optimizer = Adam(learning_rate=1e-2)

In [291]:
train_loss_results = []
train_accuracy_results = []

num_epochs = 5

for epoch in range(num_epochs):
    epoch_loss_avg = tf.keras.metrics.Mean()
    epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

    # Training loop - ds_train yields batches of 128 (set in the input pipeline)
    for x, y in ds_train:
        # Optimize the model
        loss_value, grads = grad(model, x, y)
        # Mask the gradient of the pattern-pruned conv kernel so its zeroed entries get no
        # update. The kernel variable is found by identity rather than assumed to be
        # trainable_variables[0], so this keeps working if layers are added or reordered.
        # (The kernel exists once grad() has run the model.)
        conv_layer = model.conv
        masked_kernel = conv_layer.get_conv().kernel
        grads = [conv_layer.zero_mask(g) if v is masked_kernel else g
                 for g, v in zip(grads, model.trainable_variables)]
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Track progress
        epoch_loss_avg.update_state(loss_value)  # Add current batch loss
        # Accuracy uses a second forward pass with the just-updated weights.
        epoch_accuracy.update_state(y, model(x, training=True))

    # End epoch
    train_loss_results.append(epoch_loss_avg.result())
    train_accuracy_results.append(epoch_accuracy.result())

    print("Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(epoch,
                                                                epoch_loss_avg.result(),
                                                                epoch_accuracy.result()))

Epoch 000: Loss: 0.244, Accuracy: 93.995%
Epoch 001: Loss: 0.069, Accuracy: 98.640%
Epoch 002: Loss: 0.044, Accuracy: 99.243%
Epoch 003: Loss: 0.034, Accuracy: 99.510%
Epoch 004: Loss: 0.029, Accuracy: 99.617%
