In [None]:
!jupyter nbextension enable --py widgetsnbextension
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    Layer,
)
import tensorflow_datasets as tfds

@tf.keras.utils.register_keras_serializable()
class DiagonalwiseSeparableLayer(Layer):
    """Depthwise-separable convolution built from masked dense convolutions.

    Depthwise stage: the input channels are split into `groups` equal chunks.
    Each chunk is convolved with a dense (k, k, c, c) kernel multiplied by a
    mask that is 1 only where input channel == output channel. As a result,
    every output channel depends only on its own input channel.

    Pointwise stage: a 1x1 convolution over *all* input channels produces
    `out_channels` channels.

    The layer has no bias and no activation.

    Args:
        kernel_size: Spatial size k of the square depthwise kernel.
        out_channels: Number of channels produced by the pointwise stage.
        stride: Spatial stride of the depthwise stage.
        padding: Padding string forwarded to `tf.nn.conv2d` ('SAME'/'VALID').
        group_size: Target number of channels per depthwise group. The input
            channel count must split evenly into the resulting groups.
    """

    def __init__(self, kernel_size, out_channels, stride, padding, group_size, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.group_size = group_size

    def get_mask(self, in_channels, kernel_size):
        """Return a float32 (k, k, c, c) mask that is 1 on the channel diagonal."""
        mask = np.zeros((kernel_size, kernel_size, in_channels, in_channels))
        for channel in range(in_channels):
            mask[:, :, channel, channel] = 1.
        return tf.constant(mask, dtype='float32')

    def build(self, input_shape):
        """Create one masked depthwise kernel per group plus the pointwise kernel.

        Raises:
            ValueError: if the input channel count is not divisible by the
                number of groups (tf.split in call() would otherwise fail).
        """
        in_channels = int(input_shape[-1])
        self.in_channels = in_channels
        # Integer arithmetic; always at least one group, even when
        # in_channels < group_size.
        self.groups = max(in_channels // self.group_size, 1)
        if in_channels % self.groups:
            raise ValueError(
                f"{self.name}: {in_channels} input channels cannot be split into "
                f"{self.groups} equal groups (group_size={self.group_size})."
            )
        channels = in_channels // self.groups

        self.mask = self.get_mask(channels, self.kernel_size)

        # Diagonalwise: one dense kernel per group, diagonal-masked in call().
        self.splitw = [
            self.add_weight(
                name="diagwConv" + str(i),
                shape=(self.kernel_size, self.kernel_size, channels, channels),
                trainable=True,
            )
            for i in range(self.groups)
        ]

        # Pointwise: the kernel must consume ALL concatenated channels
        # (groups * channels == in_channels), not just one group's width.
        # With the per-group width, tf.nn.conv2d treats the input as a grouped
        # convolution, so each output channel only mixes 1/groups of the inputs.
        self.pw = self.add_weight(
            name="pointwConv",
            shape=(1, 1, in_channels, self.out_channels),
            trainable=True,
        )
        super().build(input_shape)

    @tf.function
    def call(self, x):
        """Apply the masked depthwise conv per group, then the 1x1 pointwise conv."""
        strides = (1, self.stride, self.stride, 1)
        chunks = tf.split(x, self.groups, -1)
        depthwise = [
            tf.nn.conv2d(chunk, tf.multiply(kernel, self.mask), strides, self.padding)
            for chunk, kernel in zip(chunks, self.splitw)
        ]
        x = tf.concat(depthwise, -1)

        # 1x1 kernel with stride 1: the padding mode does not change the shape.
        return tf.nn.conv2d(x, self.pw, (1, 1, 1, 1), self.padding)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'kernel_size': self.kernel_size,
            'out_channels': self.out_channels,
            'group_size': self.group_size,
            'stride': self.stride,
            'padding': self.padding,
        })
        return config

In [None]:
# Caltech101 from TFDS.
# NOTE(review): the previous code unpacked split=['train', 'test'] into
# (ds_test, ds_train), i.e. it trained on the TFDS 'test' split and validated
# on 'train'. That choice is preserved here but made explicit. Confirm it is
# intended (the TFDS 'test' split is presumably the larger one — verify in the
# TFDS catalog).
TRAIN_SPLIT, TEST_SPLIT = 'test', 'train'
BATCH_SIZE = 32

(ds_train, ds_test), info = tfds.load(
    'caltech101',
    split=[TRAIN_SPLIT, TEST_SPLIT],
    as_supervised=True,
    shuffle_files=True,
    with_info=True,
)
print('Num Examples: ', info.splits[TRAIN_SPLIT].num_examples)

# Side length that every image is resized to; the model cells below use it
# for their Input shape.
size = 320


def normalize_img(image, label):
    """Resize to (size, size) and scale pixel values from [0, 255] to [0, 1].

    Aspect ratio is not preserved. tf.image.resize already returns float32
    for the default method, so the cast is only a safeguard.
    """
    image = tf.image.resize(image, (size, size))
    return tf.cast(image, tf.float32) / 255., label


# NOTE(review): cache() keeps the resized float32 images in memory
# (320*320*3*4 bytes ≈ 1.2 MB per image, assuming RGB). Consider
# cache(filename) if RAM is tight.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    # Buffer = size of the split actually used for training -> full shuffle.
    .shuffle(info.splits[TRAIN_SPLIT].num_examples)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
# Baseline: plain Conv2D stack with 'valid' padding. Spatial size:
# 320 -> 318 -> 159 -> 157 -> 78 -> 76 -> 38 -> 36 -> 18 -> 18 -> 2 -> 1.
model = tf.keras.models.Sequential([
    tf.keras.layers.Input([size, size, 3]),
    tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(256, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(128, (1,1), activation='relu'),
    tf.keras.layers.MaxPooling2D(9,9),
    tf.keras.layers.Conv2D(102, (2,2), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(102)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.summary()

# perf_counter is monotonic and high-resolution, so it is the right clock for
# elapsed durations (time.time() can jump if the system clock is adjusted).
t0 = time.perf_counter()

history = model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)

t1 = time.perf_counter()

# The History object's repr is uninformative; per-epoch metrics live in .history.
print(history.history)
print(f"* Conv2D elapsed {t1-t0}s")

Model: "sequential_30"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_85 (Conv2D)          (None, 318, 318, 32)      896       
                                                                 
 max_pooling2d_140 (MaxPooli  (None, 159, 159, 32)     0         
 ng2D)                                                           
                                                                 
 conv2d_86 (Conv2D)          (None, 157, 157, 64)      18496     
                                                                 
 max_pooling2d_141 (MaxPooli  (None, 78, 78, 64)       0         
 ng2D)                                                           
                                                                 
 conv2d_87 (Conv2D)          (None, 76, 76, 128)       73856     
                                                                 
 max_pooling2d_142 (MaxPooli  (None, 38, 38, 128)    

In [None]:
# Same topology as the baseline, but every conv after the first is a Keras
# SeparableConv2D (depthwise + pointwise).
model2 = tf.keras.models.Sequential([
    tf.keras.layers.Input([size, size, 3]),
    tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.SeparableConv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.SeparableConv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.SeparableConv2D(256, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.SeparableConv2D(128, (1,1), activation='relu'),
    tf.keras.layers.MaxPooling2D(9,9),
    tf.keras.layers.SeparableConv2D(102, (2,2), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(102)
])
model2.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model2.summary()

# Monotonic, high-resolution clock for measuring the training duration.
t0 = time.perf_counter()

history = model2.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)

t1 = time.perf_counter()

# Per-epoch metrics (the History object's repr is uninformative).
print(history.history)
print(f"* Separable Conv2D elapsed {t1-t0}s")

Model: "sequential_31"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_91 (Conv2D)          (None, 318, 318, 32)      896       
                                                                 
 max_pooling2d_145 (MaxPooli  (None, 159, 159, 32)     0         
 ng2D)                                                           
                                                                 
 separable_conv2d_46 (Separa  (None, 157, 157, 64)     2400      
 bleConv2D)                                                      
                                                                 
 max_pooling2d_146 (MaxPooli  (None, 78, 78, 64)       0         
 ng2D)                                                           
                                                                 
 separable_conv2d_47 (Separa  (None, 76, 76, 128)      8896      
 bleConv2D)                                          

In [None]:
# Same idea as the SeparableConv2D model, using the custom
# DiagonalwiseSeparableLayer (no built-in activation, hence explicit ReLU layers).
# NOTE(review): these layers use 'SAME' padding while the other two models use
# 'valid', so intermediate spatial sizes differ — not a strictly like-for-like
# timing comparison.
model = tf.keras.models.Sequential([
    tf.keras.layers.Input([size, size, 3]),
    tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    DiagonalwiseSeparableLayer(3, 64, stride=1, padding='SAME', group_size=64),
    tf.keras.layers.ReLU(),
    tf.keras.layers.MaxPooling2D(2,2),
    DiagonalwiseSeparableLayer(3, 128, stride=1, padding='SAME', group_size=64),
    tf.keras.layers.ReLU(),
    tf.keras.layers.MaxPooling2D(2,2),
    DiagonalwiseSeparableLayer(3, 256, stride=1, padding='SAME', group_size=64),
    tf.keras.layers.ReLU(),
    tf.keras.layers.MaxPooling2D(2,2),
    DiagonalwiseSeparableLayer(1, 128, stride=1, padding='SAME', group_size=64),
    tf.keras.layers.ReLU(),
    tf.keras.layers.MaxPooling2D(9,9),
    # activation='relu' already applies ReLU; a separate ReLU layer after it
    # was a no-op and has been dropped.
    tf.keras.layers.SeparableConv2D(102, (2,2), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(102)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.summary()

# Monotonic, high-resolution clock for measuring the training duration.
t0 = time.perf_counter()

# Train THIS cell's diagonalwise model. Previously `model2` (the
# SeparableConv2D model) was trained here, so the timing below was mislabelled.
history = model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)

t1 = time.perf_counter()

# Per-epoch metrics (the History object's repr is uninformative).
print(history.history)
print(f"* Diagonalwise Separable Conv2D elapsed {t1-t0}s")

Model: "sequential_32"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_92 (Conv2D)          (None, 318, 318, 32)      896       
                                                                 
 max_pooling2d_150 (MaxPooli  (None, 159, 159, 32)     0         
 ng2D)                                                           
                                                                 
 diagonalwise_separable_laye  (None, 159, 159, 64)     11264     
 r_22 (DiagonalwiseSeparable                                     
 Layer)                                                          
                                                                 
 re_lu_12 (ReLU)             (None, 159, 159, 64)      0         
                                                                 
 max_pooling2d_151 (MaxPooli  (None, 79, 79, 64)       0         
 ng2D)                                               