In [2]:
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D

In [3]:
import tensorflow_datasets as tfds
(train_ds, val_ds), ds_info = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:]'],
    as_supervised=True,
    with_info=True
)

In [4]:
IMG_SIZE = 160
def preprocess(image, label):
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE)) #resize all images to 160x160
    image = image / 255.0 #normalize to [0,1] range
    return image, label

In [5]:
train_ds = train_ds.map(preprocess).batch(32).shuffle(1000)
val_ds = val_ds.map(preprocess).batch(32)

In [6]:
#frozen_layers
base_model = MobileNetV2(input_shape=(IMG_SIZE, IMG_SIZE, 3),
                        include_top = False,
                        weights = 'imagenet')
base_model.trainable = False #freeze the base model

In [None]:
x = base_model.output
x = GlobalAveragePooling2D()(x) #add a global spatial average pooling layer
x = Dense(128, activation='relu')(x) #add a fully-connected layer   
predictions = Dense(5, activation='softmax')(x) #add a logistic layer with 5 classes

In [8]:
model = Model(inputs=base_model.input, outputs=predictions) #create the final model

In [11]:
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy']) 
model.fit(train_ds, validation_data=val_ds, epochs=5) #train the model

Epoch 1/5
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m61s[0m 567ms/step - accuracy: 0.7125 - loss: 0.7635 - val_accuracy: 0.8678 - val_loss: 0.3734
Epoch 2/5
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 537ms/step - accuracy: 0.9315 - loss: 0.2057 - val_accuracy: 0.8978 - val_loss: 0.3082
Epoch 3/5
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 508ms/step - accuracy: 0.9674 - loss: 0.1295 - val_accuracy: 0.8896 - val_loss: 0.3265
Epoch 4/5
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 509ms/step - accuracy: 0.9802 - loss: 0.0805 - val_accuracy: 0.9019 - val_loss: 0.3190
Epoch 5/5
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 514ms/step - accuracy: 0.9916 - loss: 0.0490 - val_accuracy: 0.9060 - val_loss: 0.3195


<keras.src.callbacks.history.History at 0x23668213a40>