In [1]:
import os
import sys

# append PYTHONPATH to load extensions
sys.path.append('fastaugment')
sys.path.append('sigmoid_like_tf_op')

# load TensorFlow
import tensorflow as tf
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Applied to every visible GPU; on a CPU-only machine the list is empty and the loop is a
# no-op (indexing physical_devices[0] directly would raise IndexError there).
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# import other modules
import numpy
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow.data.experimental import AUTOTUNE
import model
from fast_augment import augment, center_crop
from sigmoid_like import sigmoid_like

In [2]:
# Set up the model: build the pretrained 120-class network, freeze it, and replace its
# last Dense layer with a new 40-unit head that is the only trainable part.

# Import the model-definition module under its own alias. Binding the Keras model to the
# name `model` (which also names the module imported in the first cell) made this cell
# fail on re-run with AttributeError, because `model.make_model` no longer existed.
import model as model_lib

# Build the 120-class architecture so the pretrained weights can be loaded into it.
base_model = model_lib.make_model(input_size=385, activation=sigmoid_like, num_classes=120)

# Load pretrained 120-class model
base_model.load_weights('model.h5')

# Make all its layers non-trainable
for layer in base_model.layers:
    layer.trainable = False

# Replace the last dense layer (named 'Dense') by a new one with 40 units on output,
# fed by the same tensor that fed the old one. The new layer is trainable by default.
# NOTE(review): the Oxford-IIIT Pet dataset loaded below is documented with 37 classes;
# confirm that 40 output units (3 never-targeted logits) is intentional.
dense_input = base_model.get_layer('Dense').input
new_dense = tf.keras.layers.Dense(units=40, name='Dense')(dense_input)
model = tf.keras.models.Model(inputs=base_model.input, outputs=new_dense)

# Compile the model. The new head has no activation, hence from_logits=True; labels are
# expected one-hot (CategoricalCrossentropy / CategoricalAccuracy).
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=[tf.keras.metrics.CategoricalAccuracy()])

model.summary()

Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(None, 385, 385, 3) 0                                            
__________________________________________________________________________________________________
b0-conv (Conv2D)                (None, 191, 191, 32) 2400        input[0][0]                      
__________________________________________________________________________________________________
b0-conv-bn (BatchNormalization) (None, 191, 191, 32) 128         b0-conv[0][0]                    
__________________________________________________________________________________________________
b0-conv-act (Activation)        (None, 191, 191, 32) 0           b0-conv-bn[0][0]                 
_______________________________________________________________________________________

In [4]:
# Load dataset and build the train/test input pipelines.
# Input resolution and label depth are read from the model, so the pipeline always matches it.
image_size = model.input_shape[1]
num_classes = model.output_shape[-1]
batch_size = 64

(train_set, test_set), info = tfds.load(
    'oxford_iiit_pet',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True)

# Training images are center-cropped 6/5 larger than the model input; `augment` is then
# given prescale=image_size/crop_size and output_size=image_size.
crop_size = 6 * image_size // 5

train_set = (
    train_set
    # Deterministic preprocessing goes before cache() so it runs only once.
    .map(lambda x, y: (center_crop(x, size=[crop_size, crop_size]), y))
    .cache()
    # shuffle() and batch() come after cache(). Otherwise the first epoch's order and batch
    # composition are stored in the cache and replayed unchanged every epoch.
    .shuffle(info.splits['train'].num_examples)
    .batch(batch_size)
    # Random augmentation must also come after cache() so it is re-drawn every epoch.
    .map(lambda x, y: augment(tf.cast(x, tf.uint8),
                              tf.one_hot(y, num_classes),
                              output_size=[image_size, image_size],
                              prescale=image_size/crop_size,
                              mixup=0.5,
                              perspective=20))
    # prefetch() last, so that the whole pipeline, augmentation included, overlaps with training.
    .prefetch(AUTOTUNE)
)

test_set = (
    test_set
    .map(lambda x, y: (center_crop(x, size=[image_size, image_size]) / 255,
                       tf.one_hot(y, num_classes)))
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

In [None]:
# Preview nine augmented images from the first training batch (labels are discarded),
# as a visual sanity check of the augmentation settings.
sample_iter = train_set.take(1).unbatch().as_numpy_iterator()

fig, axes = plt.subplots(3, 3, figsize=(12, 12))
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    sample_image, _ = next(sample_iter)
    ax.imshow(sample_image)
plt.show()

In [5]:
# Train the new classification head (all pretrained layers were frozen when the model was set up)

# define LR schedule
def lr_scheduler(epoch, _, initial_lr=0.1, cliff=10, step=10):
    step = max(0, (epoch - cliff) // step + 1)
    factor = 0.5
    lr = initial_lr * (factor ** step)
    return lr

lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)
model.fit(
    train_set,
    epochs=100,
    validation_data=test_set,
    validation_freq=1,
    callbacks=[lr_callback],
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100


Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100


Epoch 98/100
Epoch 99/100
Epoch 100/100


<tensorflow.python.keras.callbacks.History at 0x7f1d40fc0dd8>

In [None]:
# Run a validation pass over the test split. return_dict=True labels each value with its
# metric name (loss, categorical_accuracy) instead of returning an unlabeled list.
model.evaluate(test_set, return_dict=True)