In [7]:
import tensorflow as tf 
import tensorflow.keras
import numpy as np
from tensorflow.keras.layers import Input, BatchNormalization, MaxPool2D, Conv2D, Dropout, Flatten, Dense,OctaveConv2D
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import fashion_mnist


# Load Fashion-MNIST and reshape it into the channels-last float tensors the model expects.
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()


def _as_image_tensor(images):
    """Scale raw pixel values by 1/255 as float32 and append a trailing channel axis."""
    scaled = images.astype(np.float32) / 255
    return tf.expand_dims(scaled, axis=-1)


x_train = _as_image_tensor(x_train)
x_test = _as_image_tensor(x_test)

# Labels also get a trailing axis, giving shape (N, 1).
y_train = tf.expand_dims(y_train, axis=-1)
y_test = tf.expand_dims(y_test, axis=-1)

# Hold out the last 10% of the training set for validation (contiguous slice, no shuffling).
train_num = round(x_train.shape[0] * 0.9)
x_valid = x_train[train_num:, ...]
y_valid = y_train[train_num:, ...]
x_train = x_train[:train_num, ...]
y_train = y_train[:train_num, ...]


# Octave Conv
# NOTE(review): stock tf.keras.layers does not appear to ship an OctaveConv2D layer;
# confirm which package actually provides it (e.g. keras-octave-conv) and fix the import.
def build_octave_model(input_shape=(28, 28, 1), num_classes=10, dropout_rate=0.4):
    """Build the octave-convolution image classifier used in this notebook.

    Layer stack: BatchNorm -> OctaveConv2D(64) -> MaxPool on each branch ->
    OctaveConv2D(32) -> OctaveConv2D(16, ratio_out=0.0), which merges to a single
    tensor -> MaxPool -> Flatten -> BatchNorm -> Dropout -> softmax Dense.

    Args:
        input_shape: shape of one input image (channels last). The default matches the
            Fashion-MNIST tensors prepared above.
        num_classes: number of units in the final softmax layer.
        dropout_rate: dropout rate applied just before the classifier head.

    Returns:
        An uncompiled ``Model`` mapping images to class probabilities.
    """
    image_input = Input(shape=input_shape)
    normalized_input = BatchNormalization()(image_input)

    # The first octave conv yields two branches (high / low); each is pooled by its own
    # layer. The high branch's pool is created first, matching the original layer naming.
    high, low = OctaveConv2D(64, kernel_size=3)(normalized_input)
    high = MaxPool2D()(high)
    low = MaxPool2D()(low)
    high, low = OctaveConv2D(32, kernel_size=3)([high, low])
    # ratio_out=0.0 -> no low branch in the output, so this layer returns a single tensor.
    merged = OctaveConv2D(16, kernel_size=3, ratio_out=0.0)([high, low])

    pooled = MaxPool2D()(merged)
    flat = Flatten()(pooled)
    normalized_features = BatchNormalization()(flat)
    regularized = Dropout(rate=dropout_rate)(normalized_features)
    class_probs = Dense(units=num_classes, activation='softmax')(regularized)
    return Model(inputs=image_input, outputs=class_probs)


model = build_octave_model()
# Adam + sparse categorical cross-entropy on integer labels; accuracy is the tracked metric.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()

# Fixed 10-epoch run, reporting metrics on the held-out validation split after each epoch.
model.fit(
    x_train,
    y_train,
    validation_data=(x_valid, y_valid),
    epochs=10,
)

# With one compiled metric, evaluate() yields [loss, accuracy]; index 1 is test accuracy.
octave_score = model.evaluate(x_test, y_test)
test_accuracy = octave_score[1]
print('Accuracy of Octave: {:.4f}'.format(test_accuracy))
"""
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 28, 28, 1)    4           input_1[0][0]                    
__________________________________________________________________________________________________
octave_conv2d (OctaveConv2D)    [(None, 28, 28, 32), 640         batch_normalization[0][0]        
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 14, 14, 32)   0           octave_conv2d[0][0]              
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 7, 7, 32)     0           octave_conv2d[0][1]              
__________________________________________________________________________________________________
octave_conv2d_1 (OctaveConv2D)  [(None, 14, 14, 16), 18496       max_pooling2d[0][0]              
                                                                 max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
octave_conv2d_2 (OctaveConv2D)  (None, 14, 14, 16)   4640        octave_conv2d_1[0][0]            
                                                                 octave_conv2d_1[0][1]            
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 7, 7, 16)     0           octave_conv2d_2[0][0]            
__________________________________________________________________________________________________
flatten (Flatten)               (None, 784)          0           max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 784)          3136        flatten[0][0]                    
__________________________________________________________________________________________________
dropout (Dropout)               (None, 784)          0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
dense (Dense)                   (None, 10)           7850        dropout[0][0]                    
==================================================================================================
Total params: 34,766
Trainable params: 33,196
Non-trainable params: 1,570
__________________________________________________________________________________________________
Epoch 1/10
1688/1688 [==============================] - 53s 31ms/step - loss: 0.5615 - accuracy: 0.8092 - val_loss: 0.3606 - val_accuracy: 0.8727
Epoch 2/10
1688/1688 [==============================] - 53s 32ms/step - loss: 0.4074 - accuracy: 0.8596 - val_loss: 0.3477 - val_accuracy: 0.8765
Epoch 3/10
1688/1688 [==============================] - 54s 32ms/step - loss: 0.3719 - accuracy: 0.8715 - val_loss: 0.2975 - val_accuracy: 0.8970
Epoch 4/10
1688/1688 [==============================] - 57s 34ms/step - loss: 0.3531 - accuracy: 0.8769 - val_loss: 0.2926 - val_accuracy: 0.8960
Epoch 5/10
1688/1688 [==============================] - 57s 34ms/step - loss: 0.3367 - accuracy: 0.8809 - val_loss: 0.3039 - val_accuracy: 0.8923
Epoch 6/10
1688/1688 [==============================] - 54s 32ms/step - loss: 0.3230 - accuracy: 0.8866 - val_loss: 0.2790 - val_accuracy: 0.8978
Epoch 7/10
1688/1688 [==============================] - 55s 33ms/step - loss: 0.3156 - accuracy: 0.8884 - val_loss: 0.3083 - val_accuracy: 0.8930
Epoch 8/10
1688/1688 [==============================] - 57s 34ms/step - loss: 0.3046 - accuracy: 0.8925 - val_loss: 0.2925 - val_accuracy: 0.8953
Epoch 9/10
1688/1688 [==============================] - 59s 35ms/step - loss: 0.2973 - accuracy: 0.8940 - val_loss: 0.2706 - val_accuracy: 0.9048
Epoch 10/10
1688/1688 [==============================] - 56s 33ms/step - loss: 0.2945 - accuracy: 0.8955 - val_loss: 0.2631 - val_accuracy: 0.9062
313/313 [==============================] - 2s 7ms/step - loss: 0.2849 - accuracy: 0.9015
Accuracy of Octave: 0.9015"""

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 28, 28, 1)    4           input_1[0][0]                    
__________________________________________________________________________________________________
octave_conv2d (OctaveConv2D)    [(None, 28, 28, 32), 640         batch_normalization[0][0]        
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 14, 14, 32)   0           octave_conv2d[0][0]              
______________________________________________________________________________________________