In [None]:
'''Trains a simple convnet on the face landmark dataset.
Adapted from Keras MNIST CNN example code.
'''

import numpy as np
np.random.seed(1337)  # for reproducibility

from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.utils import np_utils

# Training schedule.
batch_size = 128
nb_epoch = 50

# Number of landmarks predicted for every image.
nb_landmarks = 5

# Input image layout: channels first, then rows and columns.
img_chns = 3
img_rows = 16
img_cols = 16

# Per-stage settings of the convolutional part: filter count and
# max-pooling window of each stage.
nb_filters = [32, 64]
nb_pool_sizes = [(2, 2)] * 2

# Side length of the square convolution kernel.
nb_conv = 3
# Width of the fully connected layer just before the output layer.
nb_penu_neurons = 128
# Length of the output vector: two coordinates for each landmark.
nb_output_size = 2 * nb_landmarks


# Build the network: a stack of conv -> ReLU -> max-pool stages followed by a
# small fully connected regression head.
model = Sequential()

# One stage per (filter count, pool size) pair from the configuration, so the
# depth follows nb_filters / nb_pool_sizes instead of being fixed to their
# first two entries.  Layers are still created and added one at a time, in
# the same order as a hand-written two-stage stack.
for stage, (n_filters, pool_size) in enumerate(zip(nb_filters, nb_pool_sizes)):
    if stage == 0:
        # Only the first layer declares the input shape (channels, rows, cols).
        model.add(Convolution2D(n_filters, nb_conv, nb_conv,
                                border_mode='valid',
                                input_shape=(img_chns, img_rows, img_cols)))
    else:
        model.add(Convolution2D(n_filters, nb_conv, nb_conv))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=pool_size))
model.add(Dropout(0.25))

# Regression head.
model.add(Flatten())
model.add(Dense(nb_penu_neurons))
model.add(Activation('relu'))
model.add(Dropout(0.5))
# No activation after the last layer: the output is a plain linear vector of
# nb_output_size values.
model.add(Dense(nb_output_size))

# NOTE: X_train, Y_train, X_test and Y_test are created by the
# "prepare data for CNN" cell, which must have been run before this one.

# The targets are continuous landmark coordinates, so classification accuracy
# says nothing useful here; report mean absolute error next to the MSE loss.
model.compile(loss='mse',
              optimizer='adadelta',
              metrics=['mean_absolute_error'])

# The test split doubles as validation data, so the per-epoch validation
# numbers and the final evaluation below refer to the same samples.
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          verbose=1, validation_data=(X_test, Y_test))
# score[0] is the loss, score[1] the metric requested in compile().
score = model.evaluate(X_test, Y_test, verbose=0)
print('Test score:', score[0])
print('Test mean absolute error:', score[1])

In [None]:
# Run the network on the test images; Y_pred is used by the plotting cells,
# which reshape each row into one (x, y) pair per landmark.
Y_pred = model.predict(X_test)

In [None]:
# observe learning results
# One subplot per test sample: the image with the ground-truth landmarks
# (default colour) and the predicted landmarks (red) drawn on top.
# Needs plt / io from the import cells and Y_pred from the prediction cell.
n_samples = len(Y_test)
n_cols = 6
# Enough rows for every sample, so the grid does not overflow when the test
# set holds more than 30 images (30 samples still give the 5 x 6 layout).
n_rows = max(1, (n_samples + n_cols - 1) // n_cols)

plt.figure()
for i in range(n_samples) :
    plt.subplot(n_rows, n_cols, i+1)
    # images are stored channels-first; move the channel axis last for display
    img = X_test[i].transpose((1,2,0))
    io.imshow(img)
    # flat coordinate vector -> one (x, y) row per landmark, whatever their count
    pts2 = Y_test[i].reshape((-1,2))
    plt.plot(pts2[:,0], pts2[:,1], 'o')
    pts = Y_pred[i].reshape((-1,2))
    plt.plot(pts[:,0], pts[:,1], 'ro')
    plt.axis('off')
io.show()

In [None]:
# prepare data for CNN
from skimage import io
import numpy as np

# Load the dataset (330 instances).  Every root name found in the directory
# is expected to have both a ".jpg" image and a ".pts" landmark file.
# get_froot_list / read_pts come from the "helpers for reading data" cell.
dirname = "../testset/helen/testset_16/"
l = [dirname + root for root in get_froot_list(dirname)]

# Targets: the landmark array of each sample flattened into one vector.
Y = np.array([read_pts(path + ".pts").flatten() for path in l]).astype('float32')

# Inputs: move the last image axis to the front to get the
# (channels, rows, cols) layout, then divide by 255
# (assumes 8-bit pixel values -- TODO confirm).
X = np.array([np.transpose(io.imread(path + ".jpg"), (2, 0, 1)) for path in l]).astype('float32')
X /= 255

# First 300 samples are used for training, the remainder for testing.
n_train = 300
X_train, X_test = X[:n_train], X[n_train:]
Y_train, Y_test = Y[:n_train], Y[n_train:]

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

In [None]:
# observe data
from skimage import data, io, filters
import matplotlib.pyplot as plt
import numpy as np

# Sanity check on a single sample: show the image, mark its landmarks and
# label every landmark with its row index in the array returned by read_pts.
frootname = "../testset/helen/testset_16/"+"2973812451_1_16"
pts = read_pts(frootname+".pts")
image = io.imread(frootname+".jpg")

io.imshow(image)
xs, ys = pts[:,0], pts[:,1]
plt.plot(xs, ys, 'o')
for idx in range(len(pts)) :
    plt.text(pts[idx,0], pts[idx,1], str(idx))
io.show()

In [None]:
# helpers for reading data
def get_froot_list(dirname) :
    """Return the sorted, unique root names of the files directly in `dirname`.

    The root name is the part of a file name before its first extension
    separator, so "a.jpg" and "a.pts" both contribute "a".  Names whose root
    is empty (e.g. ".hidden") are dropped.  Subdirectories are not descended
    into, and a directory that cannot be walked yields an empty list.
    """
    import os
    # Only the first entry produced by os.walk (the top level) is of interest.
    _, _, filenames = next(os.walk(dirname), (dirname, [], []))
    roots = {name.split(os.path.extsep)[0] for name in filenames}
    roots.discard("")
    return sorted(roots)

def read_pts(fname) :
    """Read a landmark file and reduce it to five (x, y) points.

    The file is parsed as whitespace-separated numbers, one point per line;
    blank lines are ignored.  Only the first two values of each resulting
    point are kept.

    The returned rows are, in order:
      0. mean of points 37, 38, 40 and 41
      1. mean of points 43, 44, 46 and 47
      2. point 30
      3. point 48
      4. point 54
    (presumably the two eye centres, the nose tip and the two mouth corners
    of a 68-point annotation -- TODO confirm against the dataset's scheme)

    Parameters
    ----------
    fname : path of the landmark file.

    Returns
    -------
    numpy array of shape (5, 2).

    Raises
    ------
    ValueError if a token is not a number; IndexError if the file holds
    fewer than 55 points.
    """
    rows = []
    # Context manager so the file handle is closed as soon as parsing is done.
    with open(fname) as pts_file:
        for line in pts_file:
            fields = line.split()
            # Skip blank lines (e.g. a trailing newline); an empty row would
            # otherwise make the point list ragged.
            if fields:
                rows.append([float(x) for x in fields])
    pts = np.array(rows)
    pts = np.array([
        sum(pts[i] for i in (37,38,40,41)) / 4,
        sum(pts[i] for i in (43,44,46,47)) / 4,
        pts[30],
        pts[48],
        pts[54],
    ])[:,:2]
    return pts

In [None]:
# Overlay ground-truth landmarks (default colour) and predicted landmarks
# (red) on every test image, one subplot per sample.
n_samples = len(Y_test)
n_cols = 6
# Enough rows for every sample, so the grid does not overflow when the test
# set holds more than 30 images (30 samples still give the 5 x 6 layout).
n_rows = max(1, (n_samples + n_cols - 1) // n_cols)

plt.figure()
for i in range(n_samples) :
    plt.subplot(n_rows, n_cols, i+1)
    # images are stored channels-first; move the channel axis last for display
    img = X_test[i].transpose((1,2,0))
    io.imshow(img)
    # flat coordinate vector -> one (x, y) row per landmark, whatever their count
    pts2 = Y_test[i].reshape((-1,2))
    plt.plot(pts2[:,0], pts2[:,1], 'o')
    pts = Y_pred[i].reshape((-1,2))
    plt.plot(pts[:,0], pts[:,1], 'ro')
    plt.axis('off')
io.show()

In [None]:
# Display the raw prediction array (the output of model.predict on X_test).
Y_pred