# 5-4: Models with Conv Layers

## Code 5-4-1: Models with the Sequential API

In [2]:
import tensorflow as tf

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D

n_neurons = [10, 20, 30]  # filters for each Conv2D layer, in stacking order

# Add one 3x3 ReLU Conv2D per entry of n_neurons, so the depth of the model
# follows the list instead of being hard-coded. With the default 'valid'
# padding each layer trims 2 pixels from height and width:
# 28x28 -> 26x26 -> 24x24 -> 22x22 for the input below.
model = Sequential()
for n_neuron in n_neurons:
  model.add(Conv2D(filters=n_neuron, kernel_size=3, activation='relu'))

x = tf.random.normal(shape=(32,28,28,3))
predictions = model(x)  # the first call builds each layer's kernel and bias

print("Input: {}".format(x.shape))
print("Output: {}\n".format(predictions.shape))

# Each Conv2D holds a kernel shaped (kernel_h, kernel_w, in_channels, filters)
# and a bias shaped (filters,), as the printed shapes show.
for layer in model.layers:
  W, B = layer.get_weights()
  print(W.shape, B.shape)

print("====")
# model.trainable_variables lists the same tensors flattened: kernel, bias per layer.
trainable_variables = model.trainable_variables
for train_var in trainable_variables:
  print(train_var.shape)

Input: (32, 28, 28, 3)
Output: (32, 22, 22, 30)

(3, 3, 3, 10) (10,)
(3, 3, 10, 20) (20,)
(3, 3, 20, 30) (30,)
====
(3, 3, 3, 10)
(10,)
(3, 3, 10, 20)
(20,)
(3, 3, 20, 30)
(30,)


## Code 5-4-2: Models with Model Subclassing

In [12]:
import tensorflow as tf

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D

# Filter count for each successive Conv2D layer built by TestModel (one layer per entry).
n_neurons = [10 * depth for depth in (1, 2, 3)]

class TestModel(Model):
  """Stack of 3x3 ReLU Conv2D layers that prints shapes as data flows through.

  Args:
    filters_per_layer: Iterable of filter counts, one Conv2D layer per entry.
      Defaults to the notebook-level `n_neurons` list (read when the model is
      constructed), so a bare `TestModel()` behaves exactly as before.
    **kwargs: Forwarded to `Model.__init__` (e.g. `name`).
  """

  def __init__(self, filters_per_layer=None, **kwargs):
    super().__init__(**kwargs)
    if filters_per_layer is None:
      # Reading a module-level name needs no `global` statement (that is only
      # required for assignment); kept as a fallback for backward compatibility.
      filters_per_layer = n_neurons

    self.conv_layers = [
        Conv2D(filters=n_filters, kernel_size=3, activation='relu')
        for n_filters in filters_per_layer
    ]

  def call(self, x):
    """Applies each conv layer in turn, printing weight and activation shapes.

    Args:
      x: Input batch, presumably channels-last images such as the
        (32, 28, 28, 3) tensor used in this notebook. TODO confirm for other inputs.

    Returns:
      The output of the last conv layer (or `x` unchanged if there are none).
    """
    print("Input: ", x.shape, '\n')

    print("==== Conv Layers =====")
    for conv_layer in self.conv_layers:
      x = conv_layer(x)
      # Read shapes straight from the variables. `get_weights()` converts them
      # to NumPy and fails inside a traced tf.function (e.g. model.fit or
      # model.predict). Variable shapes are available in eager and graph mode.
      W_shape = tuple(conv_layer.kernel.shape)
      B_shape = tuple(conv_layer.bias.shape)
      print("W/B: {}/{}".format(W_shape, B_shape))
      print("X: {}\n".format(x.shape))
    return x

# A random batch of 32 images, 28x28 pixels with 3 channels (channels-last).
x = tf.random.normal(shape=(32, 28, 28, 3))
model = TestModel()
# The first call builds every Conv2D and prints the per-layer shapes from `call`.
predictions = model(x)

Input:  (32, 28, 28, 3) 

==== Conv Layers =====
W/B: (3, 3, 3, 10)/(10,)
X: (32, 26, 26, 10)

W/B: (3, 3, 10, 20)/(20,)
X: (32, 24, 24, 20)

W/B: (3, 3, 20, 30)/(30,)
X: (32, 22, 22, 30)

