In [7]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.models import Sequential

from tensorflow.keras.models import load_model

from collections import Counter

PATH = '../Data/animals10/raw-img/'

WIDTH, HEIGHT = (300, 300)
BATCH_SIZE=32
INIT_LR = 0.001
NUM_EPOCHS=10
CLASSES=10

datagen = ImageDataGenerator(rescale=1./255,
    validation_split=0.1)

train_generator = datagen.flow_from_directory(
    PATH,
    target_size=(HEIGHT, WIDTH),
    color_mode='rgb',
    class_mode='categorical',
    batch_size=BATCH_SIZE,
    subset='training')

validation_generator = datagen.flow_from_directory(
    PATH,
    target_size=(HEIGHT, WIDTH),
    color_mode='rgb',
    class_mode='categorical',
    batch_size=BATCH_SIZE,
    subset='validation')

counter = Counter(train_generator.classes)
max_val = float(max(counter.values()))
class_weights = {class_id : max_val/num_images for class_id, num_images in counter.items()}

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    base_model = load_model('base_model.h5')
    base_model.trainable = False

    pool = GlobalAveragePooling2D()
    predictions = Dense(CLASSES, activation='softmax')
    model = Sequential([base_model, pool, predictions])
    model.compile(optimizer=RMSprop(lr=INIT_LR),
                  loss=CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
                  

model.fit(
    train_generator,
    steps_per_epoch = train_generator.samples // BATCH_SIZE,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // BATCH_SIZE,
    epochs=NUM_EPOCHS,
    class_weight=class_weights
)

model.save('model.h5')

Found 23565 images belonging to 10 classes.
Found 2614 images belonging to 10 classes.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
mobilenetv2_1.00_224 (Model) (None, 10, 10, 1280)      2257984   
_________________________________________________________________
global_average_pooling2d_1 ( (None, 1280)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                12810     
Total params: 2,270,794
Trainable params: 12,810
Non-trainable params: 2,257,984
_________________________________________________________________
  ...
    to  
  ['...']
  ...
    to  
  ['...']
Train for 736 steps, validate for 81 steps
Epoch 1/10
 31/736 [>.............................] - ETA: 23:23 - loss: 4.0973 - accur

KeyboardInterrupt: 

In [3]:
base_model = MobileNetV2(input_shape=(HEIGHT, WIDTH, 3),
                             include_top=False,
                             weights='imagenet')



In [4]:
base_model.save('base_model.h5')