In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import *
import tensorflow_datasets as tfds
from time import time
import matplotlib.pyplot as plt 
# Run tf.function-wrapped code eagerly (op by op) rather than as a traced graph.
# NOTE(review): presumably required by the custom layer defined later in this
# notebook -- confirm before removing, since eager execution slows training.
tf.config.run_functions_eagerly(True)

In [69]:
%matplotlib inline

try:
  import google.colab
  IN_COLAB = True
except:
  IN_COLAB = False

if IN_COLAB:
    !pip install tvm
else:
    print("Notebook executing locally, skipping Colab setup ...")

Collecting tvm
  Downloading tvm-1.0.0.tar.gz (5.4 kB)
Collecting inform
  Downloading inform-1.26.0-py3-none-any.whl (47 kB)
[K     |████████████████████████████████| 47 kB 2.3 MB/s 
[?25hCollecting quantiphy
  Downloading quantiphy-2.15.0-py3-none-any.whl (34 kB)
Collecting arrow
  Downloading arrow-1.2.1-py3-none-any.whl (63 kB)
[K     |████████████████████████████████| 63 kB 2.0 MB/s 
Building wheels for collected packages: tvm
  Building wheel for tvm (setup.py) ... [?25l[?25hdone
  Created wheel for tvm: filename=tvm-1.0.0-py3-none-any.whl size=5101 sha256=8ca58b2291a69e9fb0df5f02433f531b908a5b7b2641dea6a4ab2ff0e09c724f
  Stored in directory: /root/.cache/pip/wheels/2a/39/5e/4c2cdaa05641090b45744676013760972b6903d8e4a481b664
Successfully built tvm
Installing collected packages: arrow, quantiphy, inform, tvm
Successfully installed arrow-1.2.1 inform-1.26.0 quantiphy-2.15.0 tvm-1.0.0


In [70]:
import tvm

In [2]:

def zero_mask(shape, zero_indices):
    """
    Build an array of ones in which the listed positions are set to zero.

    inputs:
        shape: shape of the array to create, e.g. (rows, cols, channels, filters)
        zero_indices: iterable of index tuples to zero out, e.g. [(row, col, channel, filter)];
            for every entry `idx` the element `mask[idx]` becomes 0
    returns:
        numpy array of the requested shape holding 1 everywhere except at the listed indices
    """
    result = np.ones(shape)
    for position in zero_indices:
        # Each position is used directly as a numpy index into the mask.
        result[position] = 0.
    return result

In [3]:
def pattern_maker(filters, patterns, pattern_freq, kernel_size = (3,3), channels = 1):
    """
    List the kernel positions that have to be zeroed for a set of sparsity patterns.

    Pattern k is assigned to `pattern_freq[k]` consecutive filters, starting right
    after the filters used by pattern k-1. Every position where the pattern holds 0
    yields one index tuple per (channel, filter) pair that uses the pattern.

    inputs:
        filters: total number of filters; must equal sum(pattern_freq)
        patterns: sequence of 2-D grids of size kernel_size; a 0 entry marks a masked position
        pattern_freq: number of consecutive filters using each pattern (same length as patterns)
        kernel_size: (rows, cols) of every pattern
        channels: number of input channels; each pattern is repeated over all of them
            (the default of 1 reproduces the single-channel behaviour)
    returns:
        list of (row, col, channel, filter) tuples
    raises:
        AssertionError: if `filters` differs from sum(pattern_freq), or if
            `patterns` and `pattern_freq` have different lengths
    """
    assert filters == sum(pattern_freq), "filters must equal sum(pattern_freq)"
    # zip() would silently drop the surplus entries, leaving some filters unmasked.
    assert len(patterns) == len(pattern_freq), "patterns and pattern_freq must have the same length"

    zero_indices = []
    first_filter = 0  # index of the first filter that uses the current pattern
    for pattern, freq in zip(patterns, pattern_freq):
        for i in range(kernel_size[0]):
            for j in range(kernel_size[1]):
                if pattern[i][j] == 0.:
                    zero_indices.extend(
                        (i, j, channel, first_filter + offset)
                        for channel in range(channels)
                        for offset in range(freq)
                    )
        first_filter += freq

    return zero_indices

# Ten 3x3 binary kernel patterns: positions holding 0 are the ones that
# pattern_maker reports as indices to be zeroed in the mask.
patterns = [
    [[0, 1, 0], [0, 1, 1], [0, 1, 0]],
    [[0, 1, 0], [1, 1, 1], [0, 0, 0]],
    [[0, 0, 0], [1, 1, 1], [0, 1, 0]],
    [[0, 1, 0], [1, 1, 0], [0, 1, 0]],
    [[1, 0, 0], [1, 1, 0], [1, 0, 0]],
    [[1, 1, 0], [0, 1, 0], [0, 1, 0]],
    [[0, 0, 1], [0, 1, 1], [0, 0, 1]],
    [[1, 1, 1], [0, 1, 0], [0, 0, 0]],
    [[0, 0, 0], [0, 1, 0], [1, 1, 1]],
    [[0, 0, 1], [0, 1, 0], [1, 0, 1]],
]
# Every pattern is used by four consecutive filters (10 * 4 = 40 filters in total).
pattern_freq = [4] * 10
zero_indices = pattern_maker(sum(pattern_freq), patterns, pattern_freq)
# Mask laid out as (kernel rows, kernel cols, input channels, filters).
mask = zero_mask(shape=(3, 3, 1, sum(pattern_freq)), zero_indices=zero_indices)


In [4]:
# Mini-batch size applied to both the training and the test pipeline below.
batch_size = 64

# Load the MNIST train/test splits as (image, label) pairs, plus dataset metadata.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)
def normalize_img(image, label):
  """Cast the image to `float32` and divide by 255; the label is passed through unchanged."""
  scaled = tf.cast(image, tf.float32) / 255.
  return scaled, label

# Training pipeline: normalise, cache, shuffle with a buffer spanning the whole
# training split, then batch and prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Test pipeline: no shuffling, and the cache is placed after batching.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]


[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


  "Even though the `tf.config.experimental_run_functions_eagerly` "


In [44]:
class IregConv2D(tf.keras.layers.Layer):
    """
    2-D convolution whose kernel is constrained by a fixed 0/1 sparsity mask.

    The kernel is multiplied element-wise by the mask inside the forward pass.
    Automatic differentiation therefore yields a zero gradient for every masked
    weight, so masked weights stay at their (zero) initial value during training
    and never contribute to the output. No hand-written gradient is needed, and
    the layer does not depend on eager execution.

    inputs:
        zero_mask: array-like of 0/1 values with the kernel's shape
            (kernel rows, kernel cols, input channels, filters)
        strides: int, expanded to [1, s, s, 1], or any strides value accepted by tf.nn.conv2d
        padding: padding argument forwarded to tf.nn.conv2d (default "SAME")
        *args: accepted for backward compatibility and ignored
        **kwargs: forwarded to tf.keras.layers.Layer (e.g. name=...)
    """
    def __init__(self, zero_mask, strides = 1, padding = "SAME", *args, **kwargs):
        super(IregConv2D, self).__init__(**kwargs)
        if isinstance(strides, int):
            strides = [1, strides, strides, 1]

        self.zero_mask = tf.cast(zero_mask, dtype = tf.float32)
        self.strides = strides
        self.padding = padding

        def conv2d_override(x, filters, strides, padding, data_format='NHWC', dilations=None, name=None):
            """Convolve `x` with `filters` after zeroing the masked kernel entries."""
            # Masking here makes the gradient w.r.t. `filters` equal to the plain
            # convolution gradient multiplied by the mask (chain rule).
            masked_filters = tf.multiply(filters, self.zero_mask)
            return tf.nn.conv2d(x, masked_filters, strides = strides, padding = padding,
                                data_format = data_format, dilations = dilations, name = name)

        # He-normal initial kernel with the masked entries already set to zero.
        kernel_shape = self.zero_mask.shape.as_list()
        w_init = tf.keras.initializers.HeNormal()
        w = w_init(shape = kernel_shape, dtype = 'float32')
        #self.b = self.add_weight(shape = zero_mask.shape[2:], initializer="he_normal", trainable=True) # sloppy, change this
        w = tf.multiply(self.zero_mask, w)
        self.w = tf.Variable(initial_value=w, trainable=True)
        self.conv = conv2d_override

    def call(self, x):
        """Apply the masked convolution to `x` using the layer's strides and padding (NHWC layout)."""
        return self.conv(x, self.w, self.strides, self.padding)

    def get_conv(self):
        """Return the masked-convolution function used by `call`."""
        return self.conv

In [45]:
class TestNet(tf.keras.Model):
    """Classifier built as: masked convolution -> ReLU -> flatten -> Dense(32, relu) -> Dense(10, softmax)."""

    def __init__(self):
        super().__init__()
        # The convolution kernel is constrained by the module-level `mask`.
        self.conv = IregConv2D(mask)
        self.relu1 = ReLU()
        self.flatten = Flatten()
        self.dense = Dense(32, activation = 'relu')
        self.head = Dense(10, activation = 'softmax')

    def call(self, x):
        """Run `x` through the layers in order and return the softmax output of the head."""
        for layer in (self.conv, self.relu1, self.flatten, self.dense, self.head):
            x = layer(x)
        return x

In [46]:
class BaseNet(tf.keras.Model):
    """Baseline classifier: Conv2D(40 filters, 3x3, relu) -> flatten -> Dense(32, relu) -> Dense(10, softmax)."""

    def __init__(self):
        super().__init__()
        self.conv = Conv2D(filters = 40, kernel_size = (3, 3), activation = 'relu')
        self.flatten = Flatten()
        self.dense = Dense(32, activation = 'relu')
        self.head = Dense(10, activation = 'softmax')

    def call(self, x):
        """Run `x` through the layers in order and return the softmax output of the head."""
        for layer in (self.conv, self.flatten, self.dense, self.head):
            x = layer(x)
        return x

In [47]:
# Training hyperparameters shared by both models.
num_epochs = 1
lr = 0.001

# Pattern-masked model and the ordinary convolution baseline.
ireg_model = TestNet()
reg_model = BaseNet()

In [49]:
# Train the pattern-masked model. The model head already applies softmax,
# so the loss is configured with from_logits=False.
ireg_model.compile(
    optimizer=Adam(learning_rate=lr),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

ireg_history = ireg_model.fit(
    ds_train,
    validation_data=ds_test,
    epochs=num_epochs,
    verbose=1,
)



In [50]:
# Train the baseline model with the same optimizer settings, loss and metric.
reg_model.compile(
    optimizer=Adam(learning_rate=lr),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

reg_history = reg_model.fit(
    ds_train,
    validation_data=ds_test,
    epochs=num_epochs,
    verbose=2,
)

938/938 - 53s - loss: 0.2144 - sparse_categorical_accuracy: 0.9366 - val_loss: 0.1044 - val_sparse_categorical_accuracy: 0.9695


In [61]:
# Iterate the test pipeline as NumPy arrays; each item is an (images, labels) batch.
tvm_ds = ds_test.as_numpy_iterator()
# Keep the first image of the first batch as the sample input.
ds_ = next(tvm_ds)[0][0]

In [72]:
import tvm.relay as relay

# Display the kernel variable of the masked convolution layer.
ireg_model.conv.w

In [71]:
# NOTE(review): the output recorded for this cell is an AttributeError. The
# install log earlier in the notebook shows `pip install tvm` fetching a ~5 kB
# `tvm-1.0.0` package with quantiphy/inform/arrow dependencies, which is
# presumably not Apache TVM -- confirm the correct package is installed.
#tvm_ireg = relay.frontend.from_keras(ireg_model)
# NOTE(review): from_keras appears to return a (module, params) pair rather than
# a callable, and may need an explicit input-shape dict and a non-subclassed
# Keras model -- verify against the TVM docs before calling the result.
tvm_reg = relay.frontend.from_keras(reg_model)
tvm_reg(ds_)

AttributeError: ignored