<a href="https://colab.research.google.com/github/mziad97/Transfer-learning-with-CIFAR10/blob/main/Transfer_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow.keras import layers

In [None]:
# CIFAR-10 ships a fixed train/test split; the test split is used as the
# validation set throughout this notebook.
cifar10_splits = tf.keras.datasets.cifar10.load_data()
(train_images, train_labels), (valid_images, valid_labels) = cifar10_splits

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [None]:
def preprocess_image(input_image):
  """Prepare raw images for the ImageNet-pretrained ResNet50 backbone.

  Args:
    input_image: NumPy array of images (e.g. the raw arrays returned by
      ``cifar10.load_data``). It is copied to float32 first, so the caller's
      array is not modified.

  Returns:
    A float32 array of the same shape, transformed by
    ``tf.keras.applications.resnet50.preprocess_input`` (per the Keras docs,
    "caffe"-style: RGB->BGR and ImageNet mean subtraction, no rescaling).
  """
  return tf.keras.applications.resnet50.preprocess_input(input_image.astype('float32'))

In [None]:
# Apply the same ResNet50 preprocessing to both splits (train first, then validation).
train_x, valid_x = map(preprocess_image, (train_images, valid_images))

In [None]:
# Labels are stored as a column of shape (num_images, 1), so one row is a
# length-1 array holding the integer class id.
first_label = train_labels[0]
first_label

array([6], dtype=uint8)

In [None]:
# Preprocessing must keep CIFAR-10's 32x32x3 layout, because the model's Input
# layer is declared with shape=(32, 32, 3); also make sure images and labels
# stay aligned before training.
assert train_x.shape[1:] == (32, 32, 3), train_x.shape
assert valid_x.shape[1:] == train_x.shape[1:], valid_x.shape
assert len(train_x) == len(train_labels), (len(train_x), len(train_labels))
assert len(valid_x) == len(valid_labels), (len(valid_x), len(valid_labels))
train_x.shape

(50000, 32, 32, 3)

In [None]:
def feature_extractor(inputs, input_shape=(224, 224, 3), trainable=True):
  """Run ``inputs`` through a headless, ImageNet-pretrained ResNet50 backbone.

  Args:
    inputs: Keras tensor of images whose spatial size matches ``input_shape``
      (presumably already passed through ``resnet50.preprocess_input`` --
      TODO confirm at call sites).
    input_shape: shape the ResNet50 backbone is built for. Defaults to
      (224, 224, 3), the size the upsampled CIFAR images are brought to.
    trainable: if False, the backbone weights are frozen so only layers added
      on top are trained. The default (True) fine-tunes the whole backbone,
      matching the original behaviour.

  Returns:
    The backbone's final feature map; for 224x224x3 input this is
    (None, 7, 7, 2048), as shown in the model summary.
  """
  backbone = tf.keras.applications.resnet.ResNet50(input_shape=input_shape,
                                                   include_top=False,
                                                   weights='imagenet')
  backbone.trainable = trainable
  return backbone(inputs)


def classifier(inputs, num_classes=10):
  """Dense classification head on top of a 4-D backbone feature map.

  Global average pooling already reduces (batch, h, w, channels) to
  (batch, channels), so no Flatten layer is needed afterwards.

  Args:
    inputs: 4-D feature-map tensor (e.g. the ResNet50 output).
    num_classes: number of output classes; defaults to 10 (CIFAR-10).

  Returns:
    Tensor of per-class softmax probabilities from a layer named
    'classification'.
  """
  x = layers.GlobalAveragePooling2D()(inputs)
  x = layers.Dense(1024, activation='relu')(x)
  x = layers.Dense(512, activation='relu')(x)
  return layers.Dense(num_classes, activation='softmax', name='classification')(x)

def final_model(inputs):
  """Wire the full forward graph: upsample -> ResNet50 features -> classifier.

  Args:
    inputs: Keras input tensor of images. Each spatial dimension is upsampled
      7x, so 32x32 CIFAR images become the 224x224 size that
      ``feature_extractor`` builds ResNet50 for.

  Returns:
    The output tensor of ``classifier`` (softmax class probabilities).
  """
  upsampled = layers.UpSampling2D(size=(7, 7))(inputs)
  resnet_features = feature_extractor(upsampled)
  return classifier(resnet_features)

In [None]:
def compile_model():
  """Build the CIFAR-10 transfer-learning model and compile it for training.

  Returns:
    A compiled ``tf.keras.Model`` taking 32x32x3 images. It is optimized with
    SGD on sparse categorical cross-entropy (labels are integer class ids, not
    one-hot) and reports accuracy.
  """
  image_input = layers.Input(shape=(32, 32, 3))
  model = tf.keras.Model(inputs=image_input, outputs=final_model(image_input))
  model.compile(
      optimizer='SGD',
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'],
  )
  return model

In [None]:
# Build and compile the full model (7x upsampling -> ResNet50 -> dense head).
# The first run downloads the ImageNet ResNet50 weights. summary() then prints
# each layer's output shape and parameter count.
model = compile_model()
model.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 224, 224, 3)       0         
_________________________________________________________________
resnet50 (Functional)        (None, 7, 7, 2048)        23587712  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
flatten (Flatten)            (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 1024)         

In [None]:
class MyThreshold(tf.keras.callbacks.Callback):
  """Stop training once a monitored metric reaches a threshold.

  Args:
    threshold: training stops at the end of the first epoch whose monitored
      value is >= this.
    monitor: key in the epoch-end ``logs`` dict to compare. Defaults to
      'val_accuracy', which Keras only reports when validation data is
      supplied to ``fit``.
  """

  def __init__(self, threshold, monitor='val_accuracy'):
    super().__init__()
    self.threshold = threshold
    self.monitor = monitor

  def on_epoch_end(self, epoch, logs=None):
    # Keras may call this with logs=None; treat that as "no metrics".
    logs = logs or {}
    current = logs.get(self.monitor)
    if current is None:
      # Metric not reported (e.g. fit() without validation data): warn and
      # keep training instead of crashing with KeyError/TypeError.
      tf.get_logger().warning(
          "MyThreshold: metric '%s' not found in logs (available: %s); "
          "skipping threshold check.", self.monitor, sorted(logs))
      return
    if current >= self.threshold:
      self.model.stop_training = True

In [None]:
EPOCHS = 4
BATCH_SIZE = 64
# Stop early as soon as validation accuracy reaches 90%.
myCallback = MyThreshold(threshold=0.9)
history = model.fit(
    train_x,
    train_labels,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(valid_x, valid_labels),
    callbacks=[myCallback],
)

Epoch 1/4
