In [4]:
import tensorflow as tf
import tensorflow.keras.datasets as tfds
from utils.layer_units import *
import pydotplus

# Plot configurations
%matplotlib inline

# Notebook auto reloads code. (Ref: http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython)
%load_ext autoreload
%autoreload 2

In [5]:
# load cifar10 data
(X_train,y_train),(X_test,y_test) = tfds.cifar10.load_data()


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [6]:
ndist = tf.random.normal((1,64,64,3))
ipt = tf.keras.Input(shape=(32,32,3))
resid_layer = residual_units(ipt,[[64,3,1,0],[128,2,1,0],[66,3,1,0]])

In [7]:
resid_layer.get_shape

<bound method Tensor.get_shape of <tf.Tensor 'add/Identity:0' shape=(None, 32, 32, 66) dtype=float32>>

In [8]:
mock_model = tf.keras.Model(ipt,outputs=resid_layer)

In [9]:
mock_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 32, 32, 3)    12          input_1[0][0]                    
__________________________________________________________________________________________________
activation (Activation)         (None, 32, 32, 3)    0           batch_normalization[0][0]        
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 32, 32, 64)   1792        activation[0][0]                 
______________________________________________________________________________________________

In [10]:
ndist_pooled = tf.keras.layers.MaxPool2D(pool_size=(2,2))(ndist)
ndist_pooled.shape[-1]

3

In [11]:
[[2] for _ in range(3)]

[[2], [2], [2]]

In [12]:
hg_layer = hour_glass_unit(ipt,1,True)

In [13]:
mock_model = tf.keras.Model(ipt,outputs=hg_layer)
mock_model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 16, 16, 3)    0           input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 16, 16, 3)    12          max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 16, 16, 3)    0           batch_normalization_3[0][0]      
____________________________________________________________________________________________

In [14]:
dot_img_file = 'hour_glass.png'
tf.keras.utils.plot_model(mock_model,to_file=dot_img_file, show_shapes=True)

Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.


In [15]:
attention_module = attention_unit(ipt)

In [16]:
mock_model = tf.keras.Model(ipt,outputs=attention_module)
mock_model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 32, 32, 3)    12          input_1[0][0]                    
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 32, 32, 3)    0           batch_normalization_18[0][0]     
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 32, 32, 32)   896         activation_18[0][0]              
____________________________________________________________________________________________

In [17]:
dot_img_file = 'attention_unit.png'
tf.keras.utils.plot_model(mock_model,to_file=dot_img_file, show_shapes=True)

Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.


In [22]:
flat1 = tf.keras.layers.Flatten()(attention_module)
dense1 = tf.keras.layers.Dense(256,activation='relu')(flat1)
dense2 = tf.keras.layers.Dense(10,activation='softmax')(dense1)

model = tf.keras.Model(ipt,outputs=dense2)

model.compile(optimizer = 'Adam',
    loss = 'sparse_categorical_crossentropy',
    metrics = ['sparse_categorical_accuracy'],
)
model.summary()

Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 32, 32, 3)    12          input_1[0][0]                    
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 32, 32, 3)    0           batch_normalization_18[0][0]     
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 32, 32, 32)   896         activation_18[0][0]              
____________________________________________________________________________________________

In [None]:
model.fit(X_train,y_train,epochs=10,batch_size=128,validation_split=0.2)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10