In [3]:
# Silence TensorFlow's native (C++) console logging.
# The environment variable is read when the TensorFlow library loads, so this
# cell must run before `import tensorflow` (keep it first in the notebook).
# NOTE(review): per the TensorFlow docs, level '3' also hides ERROR messages,
# not only debug/info chatter -- confirm that '2' would not be preferable.
import os
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '3'})

In [1]:
from keras.layers import Conv2D, MaxPooling2D, Convolution2D, Dropout, Dense, Flatten, LSTM
from keras.models import Sequential, save_model
from keras.utils import np_utils
from scipy.io import loadmat

import keras

import numpy as np

import pickle

import tensorflow as tf


In [2]:
def train(model, training_data, callback=False, batch_size=256, epochs=10,
          freeze_dir='freeze', model_dir='bin'):
    """Train ``model``, checkpoint it every epoch, and export the results.

    Files written (relative to the current working directory by default):

    * ``<freeze_dir>/graph.pbtxt`` and ``<freeze_dir>/graph.pb`` -- the
      TensorFlow graph as text / binary protobuf, written *before* training;
    * ``<freeze_dir>/checkpoint-<epoch>.*`` -- a ``tf.train.Saver``
      checkpoint saved at the end of every epoch;
    * ``./Graph`` -- TensorBoard logs, only when ``callback`` is true;
    * ``<model_dir>/model.yaml`` and ``<model_dir>/model.h5`` -- the model
      architecture as YAML and the full trained Keras model.

    Args:
        model: a compiled Keras model.
        training_data: tuple ``((x_train, y_train), (x_test, y_test),
            mapping, nb_classes)``. ``y_train``/``y_test`` are class vectors
            that are one-hot encoded here into ``nb_classes`` columns;
            ``mapping`` is unpacked but not used.
        callback: if true, also log training to TensorBoard.
        batch_size: mini-batch size passed to ``model.fit``.
        epochs: number of training epochs.
        freeze_dir: directory for the graph protobufs and checkpoints.
        model_dir: directory for the exported YAML / HDF5 model.
    """
    class TFCheckpointCallback(keras.callbacks.Callback):
        """Save a TF checkpoint of ``sess`` at the end of every epoch."""

        def __init__(self, saver, sess):
            # Run the base initializer so whatever state Keras' Callback sets
            # up is present; skipping it is fragile across Keras versions.
            super().__init__()
            self.saver = saver
            self.sess = sess

        def on_epoch_end(self, epoch, logs=None):
            self.saver.save(self.sess, os.path.join(freeze_dir, 'checkpoint'),
                            global_step=epoch)

    (x_train, y_train), (x_test, y_test), _mapping, nb_classes = training_data

    # Convert class vectors to binary class matrices (one-hot).
    y_train = np_utils.to_categorical(y_train, nb_classes)
    y_test = np_utils.to_categorical(y_test, nb_classes)

    callbacks = []
    if callback:
        # Callback for analysis in TensorBoard.
        callbacks.append(keras.callbacks.TensorBoard(
            log_dir='./Graph',
            histogram_freq=0,
            write_graph=True,
            write_images=True,
        ))

    # Reuse the session Keras already owns instead of installing a new one.
    # A fresh tf.Session starts with no variable values, so replacing the
    # backend session would discard any weights the model already holds.
    # Keras may also skip re-initializing variables it has initialized
    # before, which surfaces as an "uninitialized value" error when ``train``
    # is called a second time on the same model.
    sess = keras.backend.get_session()
    tf_graph = sess.graph

    # Ref: https://www.tensorflow.org/api_docs/python/tf/train/Saver
    # NOTE(review): the saver is created before ``fit``, so variables that
    # ``fit`` adds (e.g. optimizer state) are presumably not checkpointed.
    tf_saver = tf.train.Saver()
    callbacks.append(TFCheckpointCallback(tf_saver, sess))

    # Ensure the output directories exist before anything is written into
    # them; open() below does not create missing directories.
    os.makedirs(freeze_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    # Write the protobuf graph (text and binary).
    # ref: https://www.tensorflow.org/api_docs/python/tf/train/write_graph
    graph_def = tf_graph.as_graph_def()
    tf.train.write_graph(graph_def, freeze_dir, 'graph.pbtxt', as_text=True)
    tf.train.write_graph(graph_def, freeze_dir, 'graph.pb', as_text=False)

    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test),
              callbacks=callbacks)

    score = model.evaluate(x_test, y_test, verbose=0)
    # ``evaluate`` returns a bare scalar (the loss) when the model was
    # compiled without metrics, and a list ``[loss, *metrics]`` otherwise.
    if not isinstance(score, (list, tuple)):
        score = [score]
    print('Test loss:', score[0])
    if len(score) > 1:
        # Assumes the first compiled metric is accuracy.
        print('Test accuracy:', score[1])

    # Offload model to file.
    model_yaml = model.to_yaml()
    with open(os.path.join(model_dir, 'model.yaml'), 'w') as yaml_file:
        yaml_file.write(model_yaml)
    save_model(model, os.path.join(model_dir, 'model.h5'))