In [None]:
# Location of the dataset containing the images.
# NOTE(review): not referenced by the cells below, which read from '../sample/...' — confirm which is intended.
base_dir = "./dataset/train"


In [None]:
# Create image iterators for training, validation, and testing.
train_fldr = '../sample/train'
val_fldr = '../sample/val'
test_fldr = '../sample/test'

# Must match the MobileNetV2 backbone, which is built with input_shape=(224, 224, 3).
IMG_SIZE = (224, 224)
SEED = 42


def make_generator(folder, batch_size, shuffle=True):
    """Return a directory iterator over `folder` with pixels rescaled by 1/255.

    class_mode='categorical' yields one-hot labels, which is what the softmax
    output layer and the 'categorical_crossentropy' loss require
    ('binary' would yield scalar labels and fail at fit time).

    NOTE(review): MobileNetV2's ImageNet weights are normally paired with
    mobilenet_v2.preprocess_input (scales to [-1, 1]); 1/255 scaling is kept
    here so preprocessing on the Flutter side stays unchanged — confirm.
    """
    return ImageDataGenerator(rescale=1./255).flow_from_directory(
        folder,
        target_size=IMG_SIZE,
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=shuffle,
        seed=SEED)


train_generator = make_generator(train_fldr, batch_size=16)
valid_generator = make_generator(val_fldr, batch_size=16)
# One image per batch, in a fixed (unshuffled) order, for evaluation.
test_generator = make_generator(test_fldr, batch_size=1, shuffle=False)

# len(iterator) is ceil(n / batch_size): the trailing partial batch is counted,
# and the step count is never 0 when a folder holds fewer images than batch_size.
STEP_SIZE_TRAIN = len(train_generator)
STEP_SIZE_VALID = len(valid_generator)
STEP_SIZE_TEST = len(test_generator)

In [None]:
# Create labels.txt for Flutter: line i holds the class name whose label index is i.
class_indices = train_generator.class_indices

# Order by the assigned index (not alphabetically) so the file always lines up with
# the model's output positions, even if class order is ever set explicitly.
labels = '\n'.join(sorted(class_indices, key=class_indices.get))

# Explicit UTF-8 so non-ASCII class-folder names are encoded the same on every OS.
with open('labels.txt', 'w', encoding='utf-8') as f:
    f.write(labels)

class_indices


In [None]:

# Pretrained MobileNetV2 backbone (ImageNet weights) without its classification
# head; a new head is stacked on top of it in the next cell.
base_model = MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3),
)

In [None]:
# Freeze the pretrained backbone so only the newly added head is trained.
base_model.trainable = False

# Size the softmax layer from the data instead of hard-coding it (was 36), so the
# number of outputs always matches the class folders found by flow_from_directory.
num_classes = train_generator.num_classes

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(num_classes, activation='softmax'),
])

In [None]:
# Adam with default hyperparameters. categorical_crossentropy expects one-hot
# targets, so the data iterators must yield categorical labels.
adam_optimizer = tf.keras.optimizers.Adam()
model.compile(optimizer=adam_optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
# Number of full passes over the training data.
# NOTE(review): the original referenced an undefined `epochs`; 10 is a placeholder — tune as needed.
EPOCHS = 10

# Validate on `valid_generator` (the original referenced an undefined `val_generator`).
# Keras infers steps per epoch from len(generator), so steps_per_epoch is not passed.
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=valid_generator,
)

In [None]:
# Export the trained model as a SavedModel, then convert it to TFLite for Flutter.
# A dedicated directory is used: an empty path ('') is not a usable export location.
saved_model_dir = 'saved_model'
tf.saved_model.save(model, saved_model_dir)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

with open('model.tflite', 'wb') as f:
    f.write(tflite_model)