# Transfer Learning: MobileNetV2

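We fine-tune a MobileNetV2 pre-trained on ImageNet. The backbone is frozen except for its last 10 layers, and a small softmax head on top classifies the X-ray images into 3 classes.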

### Imports

In [4]:
import numpy as np
from PIL import Image

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Conv2D, Dropout, Dense, MaxPool2D, Flatten
from tensorflow.keras.callbacks import EarlyStopping

import matplotlib.pyplot as plt


### Generators for loading and augmenting images

In [6]:
# Input-pipeline settings shared by the training and test generators.
DATA_DIR = '../../data'
IMG_SIZE = (150, 150)   # must match the (height, width) of XrayModel's input_shape
BATCH_SIZE = 32
SEED = 42               # makes batch shuffling and random augmentation reproducible

# Training images: scale pixels to [0, 1] and apply light geometric augmentation.
# NOTE(review): Keras' MobileNetV2 ImageNet weights expect inputs in [-1, 1]
# (tf.keras.applications.mobilenet_v2.preprocess_input). Switching to that
# would require retraining any model saved with the current [0, 1] scaling.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.1,
    zoom_range=0.1,
    rotation_range=5,
    width_shift_range=0.1,
    height_shift_range=0.1,
    fill_mode='nearest',
)
# Test images are only rescaled - no augmentation at evaluation time.
test_datagen = ImageDataGenerator(rescale=1./255)

# One sub-directory per class; labels are one-hot encoded ('categorical').
train_generator = train_datagen.flow_from_directory(f'{DATA_DIR}/train',
                                                    target_size=IMG_SIZE,
                                                    batch_size=BATCH_SIZE,
                                                    class_mode='categorical',
                                                    seed=SEED)
test_generator = test_datagen.flow_from_directory(f'{DATA_DIR}/test',
                                                  target_size=IMG_SIZE,
                                                  batch_size=BATCH_SIZE,
                                                  class_mode='categorical',
                                                  seed=SEED)
                                            

Found 5144 images belonging to 3 classes.
Found 1288 images belonging to 3 classes.


### Model

In [7]:
class XrayModel(tf.keras.Model):
    """MobileNetV2 backbone (ImageNet weights) with a small softmax classification head.

    The whole backbone is frozen except for its last ``n_trainable`` layers,
    which are fine-tuned together with the head.

    Args:
        num_classes: number of output classes (units of the softmax layer).
        input_shape: (height, width, channels) of the input images; should
            match the generators' ``target_size`` plus 3 colour channels.
        n_trainable: how many of the backbone's final layers to unfreeze
            (0 keeps the entire backbone frozen).
        dropout_rate: dropout rate applied to the flattened features right
            before the classifier.
    """

    def __init__(self, num_classes=3, input_shape=(150, 150, 3),
                 n_trainable=10, dropout_rate=0.2):
        super().__init__()

        self.base_model = tf.keras.applications.MobileNetV2(input_shape=input_shape,
                                                            include_top=False,
                                                            weights='imagenet')
        # Freeze the whole backbone, then unfreeze only its last `n_trainable` layers.
        for layer in self.base_model.layers:
            layer.trainable = False
        # Guard 0 explicitly: layers[-0:] would select (and unfreeze) every layer.
        if n_trainable > 0:
            for layer in self.base_model.layers[-n_trainable:]:
                layer.trainable = True

        self.pool = MaxPool2D()
        self.flatten = Flatten()
        self.dropout = Dropout(dropout_rate)
        self.classifier = Dense(num_classes, activation='softmax')

    def call(self, inputs, training=None):
        """Return per-class softmax probabilities for a batch of images."""
        # NOTE(review): the Keras transfer-learning guide recommends
        # training=False here so unfrozen BatchNormalization layers keep their
        # ImageNet statistics during fine-tuning - consider it.
        x = self.base_model(inputs, training=training)
        x = self.pool(x)
        x = self.flatten(x)
        # Regularize the head; Dropout is a no-op at inference (training=False).
        x = self.dropout(x, training=training)
        return self.classifier(x)


model = XrayModel()




2022-04-05 20:07:42.129801: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2022-04-05 20:07:42.130912: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-04-05 20:07:42.132998: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


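## Load the saved model

Optional: restore a previously saved model so that training resumes from it instead of from the freshly built `XrayModel`.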

In [None]:
# Resume from the directory written by the Save cell when it exists; otherwise
# keep the freshly built XrayModel so a clean Restart & Run All still works.
MODEL_DIR = '../models/mobileNetV2'
if tf.io.gfile.isdir(MODEL_DIR):
    model = tf.keras.models.load_model(MODEL_DIR)

### Compile and train

In [22]:
# The classifier ends in a softmax, so cross-entropy is computed on
# probabilities (CategoricalCrossentropy's default from_logits=False).
# The small RMSprop learning rate is presumably chosen because pretrained
# backbone layers are being fine-tuned.
# `metrics` is passed as a list: that is the documented type, and Keras 3
# rejects a bare string.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-5),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy'],
)


In [23]:
class MyCallback(tf.keras.callbacks.Callback):
    """Stop training as soon as a logged epoch metric exceeds a threshold.

    Args:
        threshold: training stops when the monitored value is strictly
            greater than this.
        monitor: key to read from the per-epoch ``logs`` dict (the training
            accuracy by default).
    """

    def __init__(self, threshold=0.99, monitor='accuracy'):
        super().__init__()
        self.threshold = threshold
        self.monitor = monitor

    def on_epoch_end(self, epoch, logs=None):
        # `logs=None` instead of a shared mutable `{}` default.
        value = (logs or {}).get(self.monitor)
        # A missing metric (e.g. compiled under another name) is skipped
        # rather than crashing with `None > float`.
        if value is not None and value > self.threshold:
            self.model.stop_training = True


In [None]:
# Train for up to 15 epochs of 100 batches (~3200 of the 5144 training images
# per epoch), stopping early when validation loss plateaus.
# NOTE(review): the test split doubles as validation data here, so early
# stopping is tuned on the same images used for the final evaluation; a
# separate validation split would give an unbiased test score.
history = model.fit(train_generator,
                    steps_per_epoch=100,
                    epochs=15,
                    validation_data=test_generator,
                    # No validation_steps: validate on the whole test generator
                    # rather than only 3 batches (96 images), which made the
                    # val_loss driving EarlyStopping very noisy.
                    callbacks=[MyCallback(),
                               EarlyStopping(monitor='val_loss', mode='min',
                                             verbose=1, patience=5,
                                             # keep the best epoch's weights, not the last ones
                                             restore_best_weights=True)])

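## Evaluate

Note: the test set was also passed as validation data during training (it drives early stopping), so this score is likely somewhat optimistic.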

In [25]:
# Loss and accuracy on the test images, returned in the order set by compile().
model.evaluate(x=test_generator)



[3.988863229751587, 0.8330745100975037]

## Save

In [26]:
# Persist the trained model in TensorFlow SavedModel format (written as a
# directory that includes an `assets/` sub-folder).
MODEL_DIR = '../models/mobileNetV2'
model.save(MODEL_DIR, save_format='tf')



2022-04-05 21:26:51.361096: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: ../models/mobileNetV2/assets
