In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import datasets
from tensorflow.keras import callbacks
from tensorflow.keras.datasets import mnist
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications import EfficientNetB7
from tensorflow.keras.applications import EfficientNetB4
from tensorflow.keras.applications import EfficientNetB0

import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

import os

import tempfile
from os import path
%load_ext tensorboard

def normalize(image, label):
    return tf.cast(image, tf.float32) / 255., label

# Global settings

In [2]:
# todo change this to 100+

NUM_EPOCHS = 300

### Train and save model 

In [3]:
def train_model(model, ds_train, ds_validation, model_name, batch_size=64):
#    if path.exists(model_path):
#        print("Model is already trained and saved here: " + model_path)
#        return
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0003),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    
    
    save = callbacks.ModelCheckpoint(
        os.path.join('tmp', model_name + "_e_{epoch:02d}.h5"),
        monitor='loss',
        verbose=1,
        save_best_only=False,
        save_weights_only=False,
        mode='auto')

    early = callbacks.EarlyStopping(monitor='val_loss',
                                    min_delta=0,
                                    patience=30,
                                    verbose=1,
                                    mode='auto')

    hist = model.fit(ds_train, 
                     epochs=NUM_EPOCHS, 
                     validation_data=ds_validation,
                     callbacks=[save]
                    )
    
    # save model
    model.save(model_name + ".h5")
    print('Saved to: ' + model_name + ".h5")

# Image classification models

In [4]:
def convert_ds_to_tensors(ds):
    """returns tuple of train_X, train_y"""
    a = ds.map(lambda a, b: a)
    tf_list = []
    for i in a:
        tf_list.append(i)
    train_X = tf.stack(tf_list, axis=0)
    
    b = ds.map(lambda a, b: b)
    tf_list = []
    for i in b:
        tf_list.append([i.numpy()])
    train_y = np.array(tf_list, dtype=np.uint8)
    return train_X, train_y

## Beans dataset

https://www.tensorflow.org/datasets/catalog/beans

In [9]:
def preprocess_beans_train(image, label):
    image = tf.image.random_flip_left_right(image)    
    return image, label

def preprocess_beans_test_and_val(image, label):
    return image, label

### wrap model for beans

In [10]:
def wrap_model_for_beans(base_model, num_classes):
    inputs = base_model.inputs
    x = base_model.output
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes)(x)
    model = tf.keras.Model(inputs, outputs)
    return model

In [11]:
NUM_CLASSES = 3
INPUT_SHAPE = (500, 500, 3)
BATCH_SIZE = 32

In [12]:
def load_beans_datasets():
    (ds_train, ds_validation, ds_test), ds_info = tfds.load(
        'beans',
        split=['train', 'validation', 'test'],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )
    
    ds_train = ds_train.map(normalize)
    #ds_train = ds_train.map(preprocess_beans_train)
    ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples).batch(BATCH_SIZE)
    
    ds_validation = ds_validation.map(normalize)
    #ds_validation = ds_validation.map(preprocess_beans_test_and_val)
    ds_validation = ds_validation.batch(BATCH_SIZE)
    
    ds_test = ds_test.map(normalize)
    #ds_test = ds_test.map(preprocess_beans_test_and_val)
    ds_test = ds_test.batch(BATCH_SIZE)
    
    
    return ds_train, ds_validation, ds_test

In [None]:
ds_train, ds_validation, ds_test = load_beans_datasets()


### mobilenetV2

In [13]:
MODEL_NAME = 'MobileNetV2_beans_model'

In [None]:
base_model = MobileNetV2(include_top=False, weights=None, input_shape=INPUT_SHAPE)

model = wrap_model_for_beans(base_model=base_model, num_classes=NUM_CLASSES)

train_model(model=model, ds_train=ds_train, ds_validation=ds_validation, model_name=MODEL_NAME)

In [None]:
model.evaluate(ds_test)

### EfficientNets - B0, B4

#### EfficientNetB0

In [None]:
MODEL_PATH = "EfficentNetB0_beans_model"

In [None]:
base_model = EfficientNetB0(include_top=False, weights=None, input_shape=INPUT_SHAPE)

model = wrap_model_for_beans(base_model=base_model, num_classes=NUM_CLASSES)

train_model(model=model, ds_train=ds_train, ds_validation=ds_validation, model_name=MODEL_PATH)

#### EfficientNetB4

In [None]:
MODEL_PATH = 'EfficentNetB4_beans_model'

In [None]:
base_model = EfficientNetB4(include_top=False, weights=None, input_shape=INPUT_SHAPE)

model = wrap_model_for_beans(base_model=base_model, num_classes=NUM_CLASSES)

train_model(model=model, ds_train=ds_train, ds_validation=ds_validation, model_name=MODEL_PATH)

## flowers dataset

https://www.tensorflow.org/datasets/catalog/oxford_flowers102

In [14]:
def random_crop(image):
    cropped_image = tf.image.random_crop(
        image, size=[256, 256, 3])

    return cropped_image

def random_jitter(image):
    # resizing to 286 x 286 x 3
    image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # randomly cropping to 256 x 256 x 3
    image = random_crop(image)

    # random mirroring
    image = tf.image.random_flip_left_right(image)

    return image

def preprocess_flowers_train(image, label):
    image = random_jitter(image)
    return image, label

# -------------------------------

def preprocess_flowers(image, label):
    image = tf.image.resize(image, [256, 256],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return image, label

### wrap model for flowers

In [15]:
def wrap_model_for_flowers(base_model, num_classes):
    inputs = base_model.inputs
    x = base_model.output
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes)(x)
    model = tf.keras.Model(inputs, outputs)
    return model

In [16]:
NUM_CLASSES = 102
INPUT_SHAPE = (256, 256, 3)
BATCH_SIZE = 32
RESIZE_DIMENSION = 256

In [17]:
def load_flowers_dataset():  
    (ds_train, ds_validation, ds_test), ds_info = tfds.load(name="oxford_flowers102", 
                                             with_info=True,
                                             split=['train', 'validation', 'test'],  #70/15/15 split
                                             as_supervised=True)

    ds_train = ds_train.map(normalize, 
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)    
    ds_train = ds_train.map(preprocess_flowers)
    ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
    ds_train = ds_train.batch(BATCH_SIZE)
    ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
    
    ds_validation = ds_validation.map(normalize, 
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds_validation = ds_validation.map(preprocess_flowers)
    ds_validation = ds_validation.batch(BATCH_SIZE)
    ds_validation = ds_validation.prefetch(tf.data.experimental.AUTOTUNE)
    
    ds_test = ds_test.map(normalize, 
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds_test = ds_test.map(preprocess_flowers)
    ds_test = ds_test.batch(BATCH_SIZE)
    ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)
    
    return ds_train, ds_validation, ds_test

In [18]:
ds_train, ds_validation, ds_test = load_flowers_dataset()

### mobilenetV2

In [19]:
MODEL_NAME = 'MobileNetV2_flowers_model'

In [20]:
base_model = MobileNetV2(include_top=False, weights=None, input_shape=INPUT_SHAPE)

model = wrap_model_for_flowers(base_model=base_model, num_classes=NUM_CLASSES)

train_model(model=model, ds_train=ds_train, ds_validation=ds_validation, model_name=MODEL_NAME)

Epoch 1/300

Epoch 00001: saving model to tmp/MobileNetV2_flowers_model_e_01.h5
Epoch 2/300

Epoch 00002: saving model to tmp/MobileNetV2_flowers_model_e_02.h5
Epoch 3/300

Epoch 00003: saving model to tmp/MobileNetV2_flowers_model_e_03.h5
Epoch 4/300

Epoch 00004: saving model to tmp/MobileNetV2_flowers_model_e_04.h5
Epoch 5/300

Epoch 00005: saving model to tmp/MobileNetV2_flowers_model_e_05.h5
Epoch 6/300

Epoch 00006: saving model to tmp/MobileNetV2_flowers_model_e_06.h5
Epoch 7/300

Epoch 00007: saving model to tmp/MobileNetV2_flowers_model_e_07.h5
Epoch 8/300

Epoch 00008: saving model to tmp/MobileNetV2_flowers_model_e_08.h5
Epoch 9/300

Epoch 00009: saving model to tmp/MobileNetV2_flowers_model_e_09.h5
Epoch 10/300

Epoch 00010: saving model to tmp/MobileNetV2_flowers_model_e_10.h5
Epoch 11/300

Epoch 00011: saving model to tmp/MobileNetV2_flowers_model_e_11.h5
Epoch 12/300

Epoch 00012: saving model to tmp/MobileNetV2_flowers_model_e_12.h5
Epoch 13/300

Epoch 00013: saving mod

### EfficientNets - B0, B4

#### EfficientNetB0

In [21]:
MODEL_NAME = 'EfficentNetB0_flowers_model'

In [22]:
base_model = EfficientNetB0(include_top=False, weights=None, input_shape=INPUT_SHAPE)

model = wrap_model_for_flowers(base_model=base_model, num_classes=NUM_CLASSES)

train_model(model=model, ds_train=ds_train, ds_validation=ds_validation, model_name=MODEL_NAME)

Epoch 1/300

Epoch 00001: saving model to tmp/EfficentNetB0_flowers_model_e_01.h5
Epoch 2/300

Epoch 00002: saving model to tmp/EfficentNetB0_flowers_model_e_02.h5
Epoch 3/300

Epoch 00003: saving model to tmp/EfficentNetB0_flowers_model_e_03.h5
Epoch 4/300

Epoch 00004: saving model to tmp/EfficentNetB0_flowers_model_e_04.h5
Epoch 5/300

Epoch 00005: saving model to tmp/EfficentNetB0_flowers_model_e_05.h5
Epoch 6/300

Epoch 00006: saving model to tmp/EfficentNetB0_flowers_model_e_06.h5
Epoch 7/300

Epoch 00007: saving model to tmp/EfficentNetB0_flowers_model_e_07.h5
Epoch 8/300

Epoch 00008: saving model to tmp/EfficentNetB0_flowers_model_e_08.h5
Epoch 9/300

Epoch 00009: saving model to tmp/EfficentNetB0_flowers_model_e_09.h5
Epoch 10/300

Epoch 00010: saving model to tmp/EfficentNetB0_flowers_model_e_10.h5
Epoch 11/300

Epoch 00011: saving model to tmp/EfficentNetB0_flowers_model_e_11.h5
Epoch 12/300

Epoch 00012: saving model to tmp/EfficentNetB0_flowers_model_e_12.h5
Epoch 13/300


#### EfficientNetB4

In [25]:
MODEL_PATH = './flowers_models/EfficentNetB4_flowers_model.h5'

In [26]:
base_model = EfficientNetB4(include_top=False, weights=None, input_shape=INPUT_SHAPE)

model = wrap_model_for_beans(base_model=base_model, num_classes=NUM_CLASSES)

train_model(model=model, ds_train=ds_train, ds_validation=ds_validation, model_name=MODEL_PATH)

Epoch 1/300

Epoch 00001: saving model to tmp/./flowers_models/EfficentNetB4_flowers_model.h5_e_01.h5
Epoch 2/300

Epoch 00002: saving model to tmp/./flowers_models/EfficentNetB4_flowers_model.h5_e_02.h5
Epoch 3/300

Epoch 00003: saving model to tmp/./flowers_models/EfficentNetB4_flowers_model.h5_e_03.h5
Epoch 4/300

Epoch 00004: saving model to tmp/./flowers_models/EfficentNetB4_flowers_model.h5_e_04.h5
Epoch 5/300

Epoch 00005: saving model to tmp/./flowers_models/EfficentNetB4_flowers_model.h5_e_05.h5
Epoch 6/300

Epoch 00006: saving model to tmp/./flowers_models/EfficentNetB4_flowers_model.h5_e_06.h5
Epoch 7/300

Epoch 00007: saving model to tmp/./flowers_models/EfficentNetB4_flowers_model.h5_e_07.h5
Epoch 8/300

Epoch 00008: saving model to tmp/./flowers_models/EfficentNetB4_flowers_model.h5_e_08.h5
Epoch 9/300

Epoch 00009: saving model to tmp/./flowers_models/EfficentNetB4_flowers_model.h5_e_09.h5
Epoch 10/300

Epoch 00010: saving model to tmp/./flowers_models/EfficentNetB4_flow