In [1]:
import tensorflow as tf
import tensornets as nets
import os

In [2]:
CSV_COLUMNS=['emotion', 'pixels','Usage']
CSV_COLUMN_DEFAULTS=[[0],[''],['']]
 
def parse_csv(inputs):
    row_columns=tf.expand_dims(inputs, -1)
    columns=tf.decode_csv(row_columns, 
                          record_defaults=CSV_COLUMN_DEFAULTS, 
                          field_delim=',')
    features=dict(zip(CSV_COLUMNS, columns))
     
    y=features['emotion']
    y=tf.one_hot(y, depth = 7) 
    y=tf.reshape(y, [7])
 
    x=features['pixels']
    x=tf.string_split(x) 
    x=tf.sparse_tensor_to_dense(x, default_value='0') # tf.string_split() outputs a sparse matrix
    x=tf.string_to_number(x)
    x=tf.reshape(x, [48,48,1]) # add channel dimension
    x=tf.image.grayscale_to_rgb(x) # replicate this dimension to get 3 identical channels
    x=tf.image.resize_images(x, [224,224]) # resize image to 224*224
 
    return x, y

In [3]:
train_dataset=tf.data.TextLineDataset("fer2013.csv") \
        .skip(1) \
        .map(parse_csv) \
        .shuffle(10000) \
        .repeat() \
        .batch(32) \
        .prefetch(1)

In [4]:
valid_dataset=tf.data.TextLineDataset("kaggle.csv") \
        .skip(1) \
        .map(parse_csv) \
        .repeat() \
        .batch(32) \
        .prefetch(1)

In [5]:
iterator=tf.data.Iterator.from_structure(train_dataset.output_types, 
                                         train_dataset.output_shapes)
x, y=iterator.get_next()
 
training_init_op=iterator.make_initializer(train_dataset)
validation_init_op=iterator.make_initializer(valid_dataset)

In [6]:
is_train=tf.placeholder_with_default(False, shape=(), name="is_train") # placeholder for is_training

In [7]:
model=nets.VGG19(x, is_training=is_train, classes=7)
train_list=model.get_weights() # get list of weights
loss=tf.losses.softmax_cross_entropy(y, model)
  
accuracy, accuracy_op=tf.metrics.accuracy(tf.argmax(y, 1),tf.argmax(model,1)) # local vars

In [8]:
update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS) # update batch stats during training
with tf.control_dependencies(update_ops):               # only train last 2 fc layer
    train=tf.train.AdamOptimizer(1e-5).minimize(loss, var_list = train_list[34:]) 

In [10]:
load = True
init_op=tf.global_variables_initializer()
local_init_op=tf.local_variables_initializer()
epochs = 100

config=tf.ConfigProto()
config.gpu_options.allow_growth=True

with tf.Session() as sess:
    sess.run(init_op)
    sess.run(local_init_op)
    sess.run(training_init_op)
    sess.run(model.pretrained())
    saver = tf.train.Saver()
    if load:
        saver.restore("./file0" + ".ckpt")
    for epoch in range(epochs):
        for _ in range(897): # 28709 / 32
            sess.run(train, {is_train: True})
            # re-initialize the iterator, but this time with validation data
            sess.run(validation_init_op)
            for _ in range(120): # 
                sess.run(accuracy_op, {is_train: False})
                  
            print("Accumulated validation accuracy is {:.2f}%".format(sess.run(accuracy)*100))
            saver.save(sess, "./file" + "{}.ckpt".format(epoch))

Accumulated validation accuracy is 8.85%
Accumulated validation accuracy is 8.78%
Accumulated validation accuracy is 8.87%
Accumulated validation accuracy is 8.90%
Accumulated validation accuracy is 8.98%
Accumulated validation accuracy is 9.27%
Accumulated validation accuracy is 9.79%
Accumulated validation accuracy is 10.62%
Accumulated validation accuracy is 11.74%
Accumulated validation accuracy is 12.99%
Accumulated validation accuracy is 14.20%
Accumulated validation accuracy is 15.31%
Accumulated validation accuracy is 16.29%
Accumulated validation accuracy is 17.15%
Accumulated validation accuracy is 17.92%
Accumulated validation accuracy is 18.59%
Accumulated validation accuracy is 19.20%
Accumulated validation accuracy is 19.74%
Accumulated validation accuracy is 20.23%
Accumulated validation accuracy is 20.68%
Accumulated validation accuracy is 21.09%
Accumulated validation accuracy is 21.46%
Accumulated validation accuracy is 21.81%
Accumulated validation accuracy is 22.14%

KeyboardInterrupt: 