In [74]:
# !pip install --upgrade tfds-nightly

In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
import os

In [76]:
# Human-readable (Russian) class names, indexed by integer label.
# NOTE(review): presumably mirrors tf_flowers' label order
# (dandelion, daisy, tulips, sunflowers, roses) — verify against ds_info.features['label'].names.
# Every entry needs a trailing comma: without one, adjacent string literals are
# silently concatenated into a single label.
labels_list = [
    'одуванчик',
    'маргаритка',
    'тюльпаны',
    'подсолнухи',
    'розы',
]

In [55]:
# Load tf_flowers as two slices of its "train" split: the first 20% is held out
# for validation, the remaining 80% is used for training. as_supervised=True yields
# (image, label) pairs, and batch_size=32 makes every element a batch.
(valid_set, train_set), ds_info = tfds.load(
    'tf_flowers',
    split=["train[0%:20%]", "train[20%:]"],
    as_supervised=True,
    batch_size=32,
    with_info=True,
)
# Backward-compatible alias: the held-out slice was previously bound to `test_set`.
# It is the same dataset as `valid_set`; no separate test split is loaded.
test_set = valid_set
# Take the class count from the dataset metadata instead of hard-coding it.
classes_number = ds_info.features['label'].num_classes

In [58]:
from tensorflow.keras.layers import RandomRotation, RandomFlip,\
Resizing, Rescaling, RandomContrast, RandomZoom, RandomCrop

IMG_SIZE = 224

# MobileNet's ImageNet weights expect pixels in [-1, 1] (the scaling performed by
# tf.keras.applications.mobilenet.preprocess_input), so map [0, 255] -> [-1, 1]
# rather than [0, 1].
RESCALE_FACTOR = 1. / 127.5
RESCALE_OFFSET = -1.

# Deterministic preprocessing for validation/inference: resize to the network's
# input size, then rescale the pixel values.
resize_and_rescale = tf.keras.Sequential([
  Resizing(IMG_SIZE, IMG_SIZE),
  Rescaling(RESCALE_FACTOR, offset=RESCALE_OFFSET),
])

# Random augmentations, applied to training images only.
# (RandomCrop / RandomZoom were tried and are currently disabled.)
data_augmentation = tf.keras.Sequential([
  RandomFlip("horizontal_and_vertical"),
  RandomRotation(0.2),
  RandomContrast(0.2),
])

# Training pipeline: resize, augment on the raw 0-255 pixel range, then rescale.
# Flip/rotation/contrast commute with the affine rescale, so the order only matters
# if a layer clips its output.
# NOTE(review): some Keras versions' RandomContrast clip to [0, 255]; augmenting before
# rescaling keeps that safe — verify for the installed version.
augmentation = tf.keras.Sequential([
  Resizing(IMG_SIZE, IMG_SIZE),
  data_augmentation,
  Rescaling(RESCALE_FACTOR, offset=RESCALE_OFFSET),
])



In [60]:
# Labels are one-hot encoded to match the categorical_crossentropy loss used in compile().
# NOTE: this rebinds `train_set`, so re-running the cell without re-running the loading
# cell would augment and one-hot encode twice.
train_set = train_set.map(
    lambda x, y: (augmentation(x, training=True), tf.one_hot(y, classes_number))
).prefetch(tf.data.AUTOTUNE)

# Validation data is the held-out 20% slice (bound to `test_set` by the loading cell):
# resize/rescale only, no random augmentation. Building it from `test_set` avoids
# depending on a `valid_set` left over in the kernel.
valid_set = test_set.map(
    lambda x, y: (resize_and_rescale(x), tf.one_hot(y, classes_number))
).prefetch(tf.data.AUTOTUNE)

















In [64]:
from tensorflow.keras.applications.mobilenet import MobileNet
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input, MaxPooling2D


inputs = Input((IMG_SIZE, IMG_SIZE, 3))

# ImageNet-pretrained MobileNet without its classifier head, frozen and used as a
# feature extractor. It gets its own name so that `model` refers only to the full network.
base_model = MobileNet(input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False,
                       weights='imagenet')
base_model.trainable = False

# training=False keeps the base's BatchNormalization layers in inference mode, even
# if the base is unfrozen later for fine-tuning.
x = base_model(inputs, training=False)
# NOTE(review): with the default pool_size=2 and 'valid' padding this turns the 7x7
# feature map into 3x3, dropping the last row/column before averaging — confirm intended.
x = MaxPooling2D()(x)
global_average_layer = GlobalAveragePooling2D()
x = global_average_layer(x)
outputs = Dense(classes_number, activation='softmax')(x)

model = tf.keras.models.Model(inputs, outputs)

model.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet/mobilenet_1_0_224_tf_no_top.h5
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 mobilenet_1.00_224 (Functio  (None, 7, 7, 1024)       3228864   
 nal)                                                            
                                                                 
 max_pooling2d (MaxPooling2D  (None, 3, 3, 1024)       0         
 )                                                               
                                                                 
 global_average_pooling2d (G  (None, 1024)             0         
 lobalAveragePooling2D)                                          
                                                                 
 de

In [65]:
# Adam optimizer with categorical cross-entropy (labels are one-hot vectors);
# report accuracy during training.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['acc'],
)

In [66]:
modelname = 'model'
models_dir = 'models'
os.makedirs(models_dir, exist_ok=True)

# Save only when val_loss reaches a new minimum.
# Both checkpoint paths use `models_dir`, the directory that was just created.
best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(models_dir, f'{modelname}_best.hdf5'),
    monitor='val_loss',
    save_best_only=True,
    mode='min',
)
# Save unconditionally (save_best_only=False) to a fixed path, so the file ends up
# holding the most recent weights.
last_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(models_dir, f'{modelname}_last.hdf5'),
    monitor='val_loss',
    save_best_only=False,
    mode='auto',
)
# Stop training once val_loss has not improved for 5 epochs.
# NOTE(review): restore_best_weights is not set, so the in-memory model keeps the last
# epoch's weights; the best weights are in the *_best checkpoint file.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
callbacks = [best_checkpoint, last_checkpoint, early_stopping]

In [67]:
# Upper bound on training epochs; EarlyStopping may end training sooner.
epochs = 30

# Train on the augmented set, validating against the held-out slice after each epoch.
history = model.fit(
    train_set,
    validation_data=valid_set,
    epochs=epochs,
    callbacks=callbacks,
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
