# eye pose estimator

<br>

## baseline 모델 
#### - mobilenet_v2_1.0_128

<br>

## experiments
#### - sobel 필터 적용 유무
#### - classification vs regression
### - head pose perspective warping

# imports

In [1]:
import sys
import keras
import keras.backend as K

from keras.losses import sparse_categorical_crossentropy
from keras.metrics import sparse_top_k_categorical_accuracy
from keras.applications.mobilenet_v2 import MobileNetV2

Using TensorFlow backend.


In [2]:
sys.path.append("../../../")

%run ../../../ds/unity/npz/gen.py
%run ../../../ac/visualizer/plotter.py
%run ../../../ai/libs/keras/callbacks/history.py
%run ../../../ai/libs/keras/callbacks/checkpoint.py
%run ../../../ai/libs/keras/callbacks/stopper.py

# resource paths

In [3]:
npz_path = "/home/chy/archive-data/processed/unity-class-shuffled-npz"
chk_path = "/home/chy/archive-model/incubator/eye-pose/ep-{epoch:02d}-{val_loss:.7f}.hdf5"

# model

In [4]:
INPUT_SHAPE=(56,112,1)

In [5]:
epm = MobileNetV2(input_shape=INPUT_SHAPE, 
                    alpha=1.0, 
                    depth_multiplier=1, 
                    include_top=True, 
                    weights=None, 
                    input_tensor=None, 
                    pooling=None, 
                    classes=1681)

epm.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 56, 112, 1)   0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 57, 113, 1)   0           input_1[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 28, 56, 32)   288         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 28, 56, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu

# callback

In [6]:
history = TimeHistory()
checkpoint = model_checkpoint(chk_path)
stopper = EarlyStopping(monitor='val_loss', 
                        min_delta=0, 
                        patience=5, 
                        verbose=0, 
                        mode='auto', 
                        baseline=None, 
                        restore_best_weights=True)

callbacks = [history, checkpoint, stopper]

# compile

In [7]:
def top10_accuracy(y_true, y_pred):
    return sparse_top_k_categorical_accuracy(y_true, y_pred, k=10)

epm.compile(loss=sparse_categorical_crossentropy, 
            optimizer='adam', 
            metrics=['accuracy', top10_accuracy])

# hyper-params

In [8]:
EXP_CODE = "pilot"
NUM_EPOCH = 20
BATCH_SIZE = 250

# model-meta back up

In [9]:
model_json = epm.to_json()
with open("./epm-{}.json".format(EXP_CODE), "w") as json_file : 
    json_file.write(model_json)

# train

In [10]:
gen_train = UnityEyePoseGenerator(npz_base_path=npz_path, 
                                  batch_size=BATCH_SIZE, 
                                  purpose=Purpose.TRAIN,
                                  use_postprocess=True,
                                  use_aug=True)

gen_valid = UnityEyePoseGenerator(npz_base_path=npz_path, 
                                  batch_size=BATCH_SIZE, 
                                  purpose=Purpose.VALID,
                                  use_postprocess=True,
                                  use_aug=True)

*** meta verify complete [Unknown] ***
*** meta verify complete [Unknown] ***


In [None]:
epm.fit_generator(generator=gen_train,
                  validation_data=gen_valid,
                  callbacks=callbacks, 
                  workers=16, 
                  use_multiprocessing=True,
                  shuffle=True)

Epoch 1/1
 2907/23525 [==>...........................] - ETA: 4:49:26 - loss: 7.4536 - acc: 0.0010 - top10_accuracy: 0.0101