This is an example pipeline for training a softmax classifier on the 30 species.

In [None]:
from src import ds_generator
import tensorflow as tf
import util
import datetime
import os

In [None]:
# Load the training and validation splits for the species data.
# The same generator call yields both splits; the seed comes from util so the
# split is reproducible across runs (assuming the generator honours it).
train_ds, val_ds = ds_generator.DS_Generator().generate_species_data(
    util.TRAIN_SPECIES_DF,
    seed=util.SPECIES_SEED,
    batch_size=64,
    augment=1,
)

In [None]:
# build and ccompile a model
Input = tf.keras.Input((224,224,3))
# Choose some keras application
base = tf.keras.applications.inception_v3.InceptionV3(
    include_top=False,
    weights='imagenet',
    input_tensor=Input,
    input_shape=None,
    pooling="max",
    classifier_activation='softmax'
)
head = tf.keras.layers.Dense(util.NUMBER_OF_SPECIES,activation="softmax",)(base.output)
model = tf.keras.Model(inputs=Input, outputs=head,name="SomeNiceName")
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss="categorical_crossentropy",metrics=["acc"])

In [None]:
# Create the per-model output folders and the TensorBoard / checkpoint callbacks.
# Layout: <SAVING_PATH>/<model.name>/{logs,saves}/<timestamp>/...

# os.makedirs(..., exist_ok=True) is idempotent. It creates SAVING_PATH itself
# if missing (os.listdir used to crash there). It also fills in a missing
# "logs" or "saves" subfolder when the model folder already exists, which the
# old `model.name not in os.listdir(...)` check skipped. os.path.join also
# removes the dependency on SAVING_PATH ending with a slash.
model_dir = os.path.join(util.SAVING_PATH, model.name)
os.makedirs(os.path.join(model_dir, "logs"), exist_ok=True)
os.makedirs(os.path.join(model_dir, "saves"), exist_ok=True)

# A timestamp keeps each run's logs and checkpoints in separate folders.
time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# paths
log_dir = os.path.join(model_dir, "logs", time_stamp)
# "{epoch:04d}" is filled in by ModelCheckpoint at save time.
checkpoint_path = os.path.join(model_dir, "saves", time_stamp, "cp-{epoch:04d}.ckpt")

# Callback for TensorBoard logs. histogram_freq=1 records weight histograms every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Callback for model checkpoints: saves weights only, once per epoch by default.
# NOTE(review): Keras 3 requires weight files to end in ".weights.h5". The
# ".ckpt" suffix works only with tf.keras 2.x; confirm the TF version in use.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True)

In [None]:
# Train for a fixed 35 epochs, evaluating on the validation split after each one.
# Checkpoints and TensorBoard logs are written by the callbacks built above.
# The call is left as the cell's final expression so the notebook still
# displays the returned History object.
model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=35,
    callbacks=[cp_callback, tensorboard_callback],
)