In [None]:
import tensorflow as tf
import os

In [6]:
# Download the flower_photos archive; extract=True asks Keras to unpack it too.
_URL = 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz'
zip_file = tf.keras.utils.get_file(
    fname='flower_photos.tgz',
    origin=_URL,
    cache_subdir='/content',
    extract=True,
)

# NOTE(review): assumes the archive unpacks into a 'flower_photos' folder that
# sits next to the path returned by get_file -- confirm for your Keras version.
base_dir = os.path.join(os.path.split(zip_file)[0], 'flower_photos')

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz


In [20]:
# Side length (in pixels) every image is resized to, and images per batch.
IMAGE_SIZE = 224
BATCH_SIZE = 64

# A single generator serves both subsets: pixel values are multiplied by 1/255
# and validation_split=0.2 reserves a fifth of the images for validation.
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    validation_split=0.2,
    rescale=1./255,
)

# Keyword arguments shared by the training and validation iterators.
flow_kwargs = dict(
    directory=base_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
)

train_generator = datagen.flow_from_directory(subset='training', **flow_kwargs)

val_generator = datagen.flow_from_directory(subset='validation', **flow_kwargs)

Found 2939 images belonging to 5 classes.
Found 731 images belonging to 5 classes.


In [21]:
# Show the class-name -> index mapping the generator derived from the folders.
print(train_generator.class_indices)

# Write one class name per line, ordered by the class *index* rather than by
# the name. Line i of labels.txt then always names class index i, even if the
# mapping is ever built in a non-alphabetical order (e.g. an explicit
# `classes=` list). For the mapping printed above the result is unchanged.
class_indices = train_generator.class_indices
labels = '\n'.join(sorted(class_indices, key=class_indices.get))

with open('labels.txt', 'w') as f:
  f.write(labels)

{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}


In [22]:
# Height x width x 3 channels, matching the target_size used by the generators.
IMG_SHAPE = (IMAGE_SIZE,) * 2 + (3,)

# MobileNetV2 loaded with ImageNet weights and without its top
# classification layers (include_top=False).
base_model = tf.keras.applications.MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=IMG_SHAPE,
)

In [23]:
# Mark the pretrained base as non-trainable so fitting leaves its weights
# untouched; only the layers stacked on top of it are trained.
base_model.trainable = False

In [24]:
# Stack a small classification head on top of the MobileNetV2 base.
model = tf.keras.Sequential()
model.add(base_model)
model.add(tf.keras.layers.Conv2D(32, 3, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.GlobalAveragePooling2D())
# Five softmax outputs, one per flower class found by the generators.
model.add(tf.keras.layers.Dense(5, activation='softmax'))

In [27]:
# Adam with its default settings; categorical cross-entropy as the loss,
# with accuracy reported as the metric.
model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'],
)

In [28]:
# Number of passes over the training generator.
epochs = 10

# Train on the training subset and evaluate on the validation subset after
# each epoch; the returned History object is kept in `history`.
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=epochs,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [30]:
# Directory for the intermediate SavedModel export. It is defined in this cell
# because no other cell of the notebook assigns it; without this line the cell
# raises NameError after Restart Kernel -> Run All.
save_model_dir = 'saved_model'

# Export the trained Keras model as a SavedModel, then convert that export to
# a TensorFlow Lite flatbuffer.
tf.saved_model.save(model, save_model_dir)
converter = tf.lite.TFLiteConverter.from_saved_model(save_model_dir)

tflite_model = converter.convert()

# convert() returns the serialized model as bytes, hence binary mode.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)




FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.



FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.


INFO:tensorflow:Assets written to: assets


INFO:tensorflow:Assets written to: assets


In [32]:
# Colab-only helper: hand the converted model and its label file to the
# browser as downloads, model first.
from google.colab import files

for artifact in ('model.tflite', 'labels.txt'):
  files.download(artifact)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>