In [None]:
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import math
import os

In [None]:
def get_normalize(input_shape):
    """
    Compute the PCK keypoint-distance normalize coefficient for an input shape.

    The reference coefficient 6.4 corresponds to a (256, 256) input; it is
    scaled linearly by the mean of height and width relative to 256, so
    non-square shapes use the averaged side length.

    # Arguments
        input_shape: input image shape as (height, width)
    # Returns
        normalize coefficient (float)
    """
    height = input_shape[0]
    width = input_shape[1]

    # Average side length, relative to the 256-pixel reference resolution.
    mean_side = float((height + width) / 2)
    relative_scale = mean_side / 256.0

    return 6.4 * relative_scale

class EvalCallBack(tf.keras.callbacks.Callback):
    """
    Keras callback that evaluates validation accuracy after every epoch and
    saves a model checkpoint whenever the accuracy improves on the best seen.

    # Arguments
        class_names: keypoint class names (stored on the instance; not used
            by the code in this class).
        model_input_shape: model input shape as (height, width); used to
            derive the PCK normalize coefficient via get_normalize().
        log_dir: directory that improved checkpoints are written to; it is
            created on first save if missing.
        eval_images_path: path passed to np.load() for validation images.
        eval_hms_path: path passed to np.load() for validation heatmaps.
    """

    def __init__(self, class_names, model_input_shape, log_dir='.',
                 eval_images_path="../images_val_mpii",
                 eval_hms_path="../hms_val_mpii"):
        # Initialise the base Callback state that Keras expects to exist.
        super().__init__()
        self.class_names = class_names
        self.normalize = get_normalize(model_input_shape)
        self.model_input_shape = model_input_shape
        # Previously read in on_epoch_end but never assigned, which raised
        # AttributeError on the first improvement.
        self.log_dir = log_dir
        self.best_acc = 0.0

        # NOTE(review): np.load does not append ".npy"; if these arrays were
        # written with np.save the files may be "<name>.npy" -- confirm.
        self.eval_images = np.load(eval_images_path)
        self.eval_hms = np.load(eval_hms_path)

    def on_epoch_end(self, epoch, logs=None):
        """
        Evaluate accuracy and checkpoint the model if it improved.

        # Arguments
            epoch: zero-based epoch index supplied by Keras.
            logs: metrics dict supplied by Keras; may be None.
        """
        # logs defaults to None; guard so .get() below cannot fail.
        logs = logs or {}

        output = self.model.predict(self.eval_images)
        print(output.shape)
        # NOTE(review): presumably 16 keypoints x 64x64 heatmaps (MPII) --
        # confirm against the model head; reshape raises if sizes differ.
        output = output.reshape( (output.shape[0],)+(16,64,64) )
        # NOTE(review): `output` and `self.eval_hms` are not passed to
        # accuracy(), which is defined outside this class -- confirm whether
        # accuracy should compare these instead of re-running the model.
        val_acc = accuracy(self.model)
        print('validate accuray', val_acc, '@epoch', epoch)

        if val_acc > self.best_acc:
            # A missing loss would make the ':.3f' format raise TypeError.
            loss = logs.get('loss')
            if loss is None:
                loss = float('nan')

            # Save best accuracy value and model checkpoint.
            os.makedirs(self.log_dir, exist_ok=True)
            checkpoint_path = os.path.join(self.log_dir, 'ep{epoch:03d}-loss{loss:.3f}-val_acc{val_acc:.3f}.h5'.format(epoch=(epoch+1), loss=loss, val_acc=val_acc))
            self.model.save(checkpoint_path)
            print('Epoch {epoch:03d}: val_acc improved from {best_acc:.3f} to {val_acc:.3f}, saving model to {checkpoint_dir}'.format(epoch=epoch+1, best_acc=self.best_acc, val_acc=val_acc, checkpoint_dir=checkpoint_path))
            self.best_acc = val_acc
        else:
            print('Epoch {epoch:03d}: val_acc did not improve from {best_acc:.3f}'.format(epoch=epoch+1, best_acc=self.best_acc))