In [8]:
import tools
import os
import tensorflow as tf
import models

# Logger for this notebook; do_file_logging=False is passed to the project-local
# tools.get_logger (presumably console-only logging — confirm in tools).
__logger = tools.get_logger(__name__, do_file_logging=False)

# Checkpoint files written by the ModelCheckpoint callback start with this prefix;
# try_load_checkpoint uses it to pick checkpoint files out of checkpoint_dir.
__checkpoint_file_prefix = "wrist_est"

# Root directory for all training artefacts. Set WRIST_CNN_DATA_DIR to use another
# machine/location instead of editing this absolute, machine-specific default.
data_dir = os.environ.get("WRIST_CNN_DATA_DIR", "E:\\MasterDaten\\Results\\wrist_cnn")
checkpoint_dir = os.path.join(data_dir, "checkpoints")      # per-epoch weight files
saved_model_dir = os.path.join(data_dir, "saved_models")    # SavedModel exports (numbered by default)
log_dir = os.path.join(data_dir, "logs")
tensorboard_dir = os.path.join(log_dir, "tensorboard")
        

In [9]:
def try_load_checkpoint(model, checkpoint_dir):
    """Load the most recent readable checkpoint in ``checkpoint_dir`` into ``model``.

    Candidates are regular files in ``checkpoint_dir`` whose name contains
    ``__checkpoint_file_prefix``. They are tried newest-first (by modification time).
    A file that fails to load is logged and the next candidate is tried.

    Args:
        model: Keras model whose ``load_weights`` is called on each candidate.
        checkpoint_dir: Directory to scan. A missing directory is not an error.

    Returns:
        ``model`` after the first successful ``load_weights`` call, or ``None`` if the
        directory does not exist, holds no matching files, or every file failed to load.
    """
    if not os.path.isdir(checkpoint_dir):
        return None

    candidates = []
    for filename in os.listdir(checkpoint_dir):
        path = os.path.abspath(os.path.join(checkpoint_dir, filename))
        if os.path.isfile(path) and __checkpoint_file_prefix in filename:
            candidates.append(path)

    # Newest first. getmtime rather than getctime: on Unix ctime is the inode-change
    # time (not creation time), so it does not reliably identify the latest checkpoint.
    for path in sorted(candidates, key=os.path.getmtime, reverse=True):
        try:
            __logger.info("Trying to load weights from {}".format(path))
            model.load_weights(path)
        except Exception as e:
            # Best effort: a corrupt or incompatible file must not abort the search.
            __logger.exception(e)
            __logger.error("Loading of {} failed!".format(path))
            continue
        __logger.info("Loading successful!")
        return model

    return None

def save(model, save_name=None):
    """Export ``model`` as a TensorFlow SavedModel below ``saved_model_dir``.

    Args:
        model: Object passed to ``tf.saved_model.save`` (e.g. the Keras model).
        save_name: Sub-directory name to export into. If omitted, a new numbered
            directory is created, one higher than the largest existing numeric
            sub-directory (``1`` when none exists yet).

    Returns:
        The unchanged ``model``, so the call can be chained.
    """
    os.makedirs(saved_model_dir, exist_ok=True)

    if save_name:
        export_dir = os.path.join(saved_model_dir, save_name)
    else:
        # Only purely numeric sub-directories count as versions. Compare them as ints:
        # as strings "10" < "9", and adding 1 to a str raises TypeError.
        versions = [
            int(d) for d in os.listdir(saved_model_dir)
            if d.isdecimal() and os.path.isdir(os.path.join(saved_model_dir, d))
        ]
        export_dir = os.path.join(saved_model_dir, str(max(versions, default=0) + 1))

    tf.saved_model.save(model, export_dir)
    return model

def try_load_saved(model, save_name=None):
    """Load a SavedModel export from ``saved_model_dir``.

    Args:
        model: Unused; kept so existing callers keep working. The loaded object is
            returned rather than loaded into this model.
        save_name: Sub-directory of ``saved_model_dir`` to load. If omitted, the
            numeric sub-directory with the highest number (newest version) is used.

    Returns:
        The object returned by ``tf.saved_model.load`` (a generic trackable object,
        not necessarily a compiled Keras model), or ``False`` if ``saved_model_dir``
        or the requested/newest export does not exist.
    """
    if not os.path.isdir(saved_model_dir):
        return False

    if save_name:
        export_dir = os.path.join(saved_model_dir, save_name)
    else:
        # Numeric sub-directories are versions; compare as ints, not strings.
        versions = [
            int(d) for d in os.listdir(saved_model_dir)
            if d.isdecimal() and os.path.isdir(os.path.join(saved_model_dir, d))
        ]
        if not versions:
            return False
        # Load the newest existing version.
        export_dir = os.path.join(saved_model_dir, str(max(versions)))

    if not os.path.isdir(export_dir):
        return False
    return tf.saved_model.load(export_dir)

In [10]:
# Build the wrist-estimation network defined in the project-local `models` module.
# NOTE(review): the saved output of this cell is a KeyboardInterrupt, so this run was
# interrupted here; later cells (compile/fit) need `model` — re-run from the top.
model = models.WristCNN()



KeyboardInterrupt



In [None]:
def prepare_ds(ds, image_size=227, max_depth=2500.0, colormap="viridis"):
    """Map ``(img, skel)`` pairs to network-ready ``(colorized_img, skel[:3])`` pairs.

    Each image is resized to ``image_size`` x ``image_size`` and divided by
    ``max_depth``. It is then clipped to ``[0, 1]``, so every value at or beyond
    ``max_depth`` saturates to 1.0; such pixels are NOT removed.
    (img is presumably a depth map in mm, making the default 2.5 m — TODO confirm.)
    The normalised image is then colour-mapped with ``tools.colorize``. The target
    keeps only ``skel[:3]``.

    Args:
        ds: ``tf.data.Dataset`` yielding ``(img, skel)`` tuples.
        image_size: Output height and width in pixels (227 by default).
        max_depth: Image value mapped to 1.0 before clipping.
        colormap: Colormap name forwarded to ``tools.colorize``.

    Returns:
        The mapped ``tf.data.Dataset``.
    """
    target_size = tf.constant([image_size, image_size], dtype=tf.dtypes.int32)
    scale = tf.constant(max_depth, dtype=tf.float32)

    def normalise(img, skel):
        resized = tf.cast(tf.image.resize(img, target_size), dtype=tf.float32)
        clipped = tf.clip_by_value(resized / scale, clip_value_min=0.0, clip_value_max=1.0)
        return clipped, skel[:3]

    def colorize(img, skel):
        return tools.colorize(img, 0.0, 1.0, colormap), skel

    ds = ds.map(normalise, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(colorize)
    return ds


In [None]:
# Root of the NYU hand dataset (v2). Machine-specific default; set
# NYU_HAND_DATASET_DIR to point elsewhere without editing the notebook.
NYU_DATASET_DIR = os.environ.get(
    "NYU_HAND_DATASET_DIR", "G:\\master_thesis_data\\Datasets\\nyu\\nyu_hand_dataset_v2\\dataset")
BATCH_SIZE = 25              # shared by the training and validation pipelines
SHUFFLE_BUFFER_SIZE = 100    # number of examples held in the training shuffle buffer

dataset_train = tools.NYU.get_dataset(NYU_DATASET_DIR, "train")
dataset_validation = tools.NYU.get_dataset(NYU_DATASET_DIR, "validation")

dataset_train = prepare_ds(dataset_train)
dataset_validation = prepare_ds(dataset_validation)

# Only the training split is shuffled; prefetch overlaps input prep with training.
dataset_train = dataset_train.shuffle(SHUFFLE_BUFFER_SIZE)
dataset_train = dataset_train.batch(batch_size=BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)

dataset_validation = dataset_validation.batch(batch_size=BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
learning_rate = 0.0005
# Start every run from a clean TensorBoard directory
# (assumes tools.clean_tensorboard_logs removes old event files — project helper).
tools.clean_tensorboard_logs(tensorboard_dir)

# clipvalue=10 clips each gradient element to [-10, 10] before the Adam update.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, clipvalue=10)
# Regression with an MSE loss on continuous targets: accuracy metrics compare discrete
# labels and carry no meaning here, so only the mean absolute error is tracked.
model.compile(optimizer=optimizer, loss=tf.keras.losses.mean_squared_error, metrics=["mae"])

os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(tensorboard_dir, exist_ok=True)

# Weights only: try_load_checkpoint restores these files with model.load_weights,
# and Keras cannot write a full subclassed model to HDF5.
checkpointer = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(checkpoint_dir, __checkpoint_file_prefix + "weights.{epoch:02d}.hdf5"),
        save_weights_only=True,
        save_best_only=False)

tensorboard = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, histogram_freq=0,
                                             write_graph=True, write_images=True, update_freq='batch', profile_batch=0)

In [None]:
# Train for 300 epochs, validating after each epoch; the callbacks write a checkpoint
# file per epoch and TensorBoard summaries every batch (update_freq='batch').
model.fit(
    x=dataset_train,
    epochs=300,
    validation_data=dataset_validation,
    callbacks=[checkpointer, tensorboard],
    verbose=2,  # one summary line per epoch
)