In [1]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models

%matplotlib inline

## Dataset

In [2]:
%%capture
fashion_dataset = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_dataset.load_data()

In [3]:
# Show the raw training array shapes: images first, then labels.
for arr in (train_images, train_labels):
    print(arr.shape)

(60000, 28, 28)
(60000,)


In [4]:
print("BEFORE NORMALIZING:")
# Maxima of the raw data: pixel values span 0-255, class labels span 0-9.
for caption, values in (("IMAGES: ", train_images), ("LABELS: ", train_labels)):
    print(caption, values.max())

BEFORE NORMALIZING:
IMAGES:  255
LABELS:  9


In [5]:
def to_unit_range(images):
    """Return ``images`` scaled from [0, 255] to float32 values in [0, 1].

    Integer arrays are cast to float32 *before* dividing: dividing the raw
    integer array by 255.0 would yield float64, doubling memory for data that
    is converted to float32 later anyway. Non-integer arrays are returned
    unchanged, so re-running this cell does not divide by 255 a second time.
    """
    if np.issubdtype(images.dtype, np.integer):
        return images.astype(np.float32) / 255.0
    return images


train_images = to_unit_range(train_images)
test_images = to_unit_range(test_images)

print("AFTER NORMALIZING:")
print("IMAGES: ",train_images.max()) # RANGE: 0-1
print("IMAGES: ",test_images.max()) # RANGE: 0-1

AFTER NORMALIZING:
IMAGES:  1.0
IMAGES:  1.0


In [6]:
# plt.figure(figsize=(2,2))
# plt.imshow(train_images[100])

In [7]:
# CONVERT (x,y) -> (x,y,c)
# Append a single grayscale channel axis, (N, 28, 28) -> (N, 28, 28, 1), since the
# Conv2D layers consume channels-last input. Reshaping to shape[:3] + (1,) is
# idempotent: re-running the cell on an already 4-D array leaves it unchanged,
# whereas reshape(*shape, 1) would add a fifth axis on a second run.
# copy=False skips the copy when the array is already float32.
train_images = train_images.reshape(*train_images.shape[:3], 1).astype(np.float32, copy=False)
test_images = test_images.reshape(*test_images.shape[:3], 1).astype(np.float32, copy=False)
print(train_images.shape)
print(test_images.shape)

(60000, 28, 28, 1)
(10000, 28, 28, 1)


## Network

In [8]:
class ConvUnit(layers.Layer):
    """Conv2D -> BatchNormalization -> ReLU building block.

    Args:
        out_channels: number of convolution filters.
        kernel_size: int or (h, w) tuple forwarded to ``layers.Conv2D``, e.g. 3 or (3, 3).
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, out_channels, kernel_size, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): the conv bias is redundant in front of BN (BN re-centres the
        # output); use_bias=False would drop it. Kept to preserve the parameter count.
        self.conv = layers.Conv2D(out_channels, kernel_size)
        self.bn   = layers.BatchNormalization()

    def call(self, input_tensor, training=None):
        """Apply convolution, batch normalization and ReLU.

        ``training`` controls BatchNormalization (batch statistics when True,
        moving averages when False). The default ``None`` follows the Keras
        convention: the flag is taken from the enclosing call, e.g. ``fit``
        versus ``evaluate``/``predict``.
        """
        t = self.conv(input_tensor)
        t = self.bn(t, training=training)
        t = tf.nn.relu(t)                                       # activation applied manually, after BN
        return t
    
class LinearUnit(layers.Layer):
    """Single ``layers.Dense`` whose activation is applied by the Dense layer itself.

    Args:
        out_size: number of output units.
        activation: activation name or callable passed to ``layers.Dense``.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, out_size, activation, **kwargs):
        super().__init__(**kwargs)
        self.fc = layers.Dense(out_size, activation=activation) # layers ACTIVATION

    def call(self, input_tensor):
        """Return ``activation(input_tensor @ W + b)`` from the wrapped Dense layer."""
        return self.fc(input_tensor)

In [9]:
class CustomModel(models.Model):
    """Two ConvUnits followed by three dense layers for multi-class image classification.

    Args:
        image_shape: per-example input shape, e.g. (28, 28, 1). The weights are
            built eagerly from it so ``summary()`` works before training.
        num_classes: number of output classes (default 10).
    """

    def __init__(self, image_shape, num_classes=10):
        super().__init__()
        self.conv1 = ConvUnit(32, 3)                       # activated - custom (ReLU inside ConvUnit)
        self.conv2 = ConvUnit(64, 3)
        # Created once here rather than inside call(): the layer is then tracked by
        # the model and not re-instantiated on every forward pass / trace.
        self.flatten = layers.Flatten()
        self.fc1 = LinearUnit(1024, activation='relu')     # activated - layers
        self.fc2 = LinearUnit(128 , activation='relu')
        # Softmax, not sigmoid: the classes are mutually exclusive, and this notebook
        # compiles with sparse_categorical_crossentropy (from_logits=False), which
        # expects each output row to be a probability distribution summing to 1.
        self.out = LinearUnit(num_classes, activation='softmax')
        # GENERATE SUMMARY: build weights now so summary()/model() work before fit().
        self.image_shape = image_shape
        self.build(input_shape=(None, *image_shape))

    def call(self, input_tensor, training=None):
        """Forward pass; ``training`` is forwarded to the BatchNorm-bearing ConvUnits."""
        t = self.conv1(input_tensor, training=training)
        t = self.conv2(t, training=training)
        t = self.flatten(t)
        t = self.fc1(t)
        t = self.fc2(t)
        return self.out(t)

    def model(self):
        """Return a functional ``keras.Model`` wrapping this subclassed model.

        Tracing ``call`` on a symbolic ``keras.Input`` lets ``summary()`` report
        concrete per-layer output shapes.
        """
        t = keras.Input(shape=self.image_shape)
        return keras.Model(inputs=[t], outputs=self.call(t))
        

In [10]:
model = CustomModel(image_shape=(28,28,1))

# Labels are integer class ids, so use the *sparse* categorical cross-entropy.
# from_logits=False (the default, spelled out here) means the loss expects the
# network's outputs to already be probabilities rather than raw logits.
model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

In [11]:
functional_view = model.model()   # functional wrapper so summary() lists per-layer output shapes
functional_view.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv_unit (ConvUnit)         (None, 26, 26, 32)        448       
_________________________________________________________________
conv_unit_1 (ConvUnit)       (None, 24, 24, 64)        18752     
_________________________________________________________________
flatten (Flatten)            (None, 36864)             0         
_________________________________________________________________
linear_unit (LinearUnit)     (None, 1024)              37749760  
_________________________________________________________________
linear_unit_1 (LinearUnit)   (None, 128)               131200    
_________________________________________________________________
linear_unit_2 (LinearUnit)   (None, 10)               