In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.python.tools import freeze_graph
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.python.platform import gfile
from google.protobuf import text_format

In [5]:
# Training hyper-parameters.
NUM_STEP = 2000     # number of mini-batch optimisation steps
BATCH_SIZE = 100    # examples per training mini-batch
SEED = 42           # makes weight init, dropout masks and batching reproducible

# Start from an empty graph so re-running this cell does not duplicate ops.
tf.reset_default_graph()
# The graph-level seed is stored on the default graph, so it must be set
# *after* the reset above or it would be discarded with the old graph.
tf.set_random_seed(SEED)
# Also seed NumPy's global RNG in case any host-side shuffling relies on it.
np.random.seed(SEED)

def model_input(input_node_name,keep_prob_node_name):
    """Create the three feed placeholders of the graph.

    Args:
        input_node_name: graph name for the image placeholder (float32,
            shape ``[None, 28, 28, 1]``).
        keep_prob_node_name: graph name for the scalar dropout keep-probability
            placeholder.

    Returns:
        ``(x, keep_prob, y_)``: image placeholder, keep-probability
        placeholder, and an unnamed float32 label placeholder of shape
        ``[None, 10]``.
    """
    # Creation order is preserved so TF's auto-generated names stay the same.
    images = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name=input_node_name)
    keep_probability = tf.placeholder(tf.float32, name=keep_prob_node_name)
    labels = tf.placeholder(tf.float32, shape=[None, 10])
    return images, keep_probability, labels

def build_model(x,keep_prob,y_,output_node_name):
    """Build the two-conv-layer MNIST CNN and its training/evaluation ops.

    Args:
        x: image placeholder; reshaped here to ``[-1, 28, 28, 1]``.
        keep_prob: scalar float placeholder holding the probability of
            *keeping* a unit in the dropout layer (the training loop feeds 0.5
            for optimisation steps and 1.0 for evaluation).
        y_: label placeholder of shape ``[None, 10]`` (one-hot, as fed by the
            training loop).
        output_node_name: name of the softmax output tensor; this node is
            frozen and exported to TFLite.

    Returns:
        ``(train_step, loss, accuracy, merged_summary_op)``.
    """
    x_image = tf.reshape(x, [-1, 28, 28, 1])

    # conv1 -> 28x28x32, then 2x2 max-pool -> 14x14x32
    conv1 = tf.layers.conv2d(
        inputs=x_image,
        filters=32,
        kernel_size=[5, 5],
        padding='same',
        activation=tf.nn.relu)
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

    # conv2 -> 14x14x64, then 2x2 max-pool -> 7x7x64
    conv2 = tf.layers.conv2d(
        inputs=pool1,
        filters=64,
        kernel_size=[5, 5],
        padding='same',
        activation=tf.nn.relu)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
    dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)

    # tf.layers.dropout is a no-op unless training=True, and its `rate` is the
    # *drop* probability (the inverse of keep_prob). tf.nn.dropout takes the
    # keep probability directly, matching the values the training loop feeds.
    #
    # One Dense layer object is applied twice so both paths share weights:
    #   - training path: dropout driven by keep_prob, used for the loss;
    #   - inference path: no dropout, used for the exported output. This keeps
    #     the `output` node independent of the keep_prob placeholder, so
    #     freezing/TFLite conversion only needs the image input.
    logits_layer = tf.layers.Dense(units=10)
    train_logits = logits_layer(tf.nn.dropout(dense, keep_prob))
    logits = logits_layer(dense)
    outputs = tf.nn.softmax(logits, name=output_node_name)

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=train_logits))

    train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)

    correct_prediction = tf.equal(tf.argmax(outputs, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    tf.summary.scalar("loss", loss)
    tf.summary.scalar("accuracy", accuracy)
    merged_summary_op = tf.summary.merge_all()

    return train_step, loss, accuracy, merged_summary_op

def _batched_accuracy(accuracy, x, y_, keep_prob, images, labels, batch_size):
    """Evaluate `accuracy` over a whole dataset in chunks; return the overall fraction.

    Must run while a default session is active (it uses ``Tensor.eval``).
    Each chunk's accuracy is converted back to a correct-prediction count, so
    the result equals a single full-dataset evaluation while peak activation
    memory stays bounded by ``batch_size``.

    Raises:
        ValueError: if ``images`` is empty or ``batch_size`` is not positive.
    """
    num_examples = len(images)
    if num_examples == 0:
        raise ValueError('cannot evaluate accuracy on an empty dataset')
    if batch_size <= 0:
        raise ValueError('batch_size must be positive')
    num_correct = 0
    for start in range(0, num_examples, batch_size):
        chunk_images = images[start:start + batch_size].reshape(-1, 28, 28, 1)
        chunk_labels = labels[start:start + batch_size]
        chunk_acc = accuracy.eval(
            feed_dict={x: chunk_images, y_: chunk_labels, keep_prob: 1.0})
        # accuracy * size is an integer count up to float rounding.
        num_correct += int(round(float(chunk_acc) * len(chunk_labels)))
    return num_correct / float(num_examples)


def train(x,keep_prob,y_,train_step,loss,accuracy,merged_summary_op,saver,
          model_dir='/home/hari304/MNIST/model', eval_batch_size=1000):
    """Train the model on MNIST, save graph + checkpoint, and report test accuracy.

    Writes ``mnist_cnn.pbtxt`` (text GraphDef) and the ``mnist_cnn.chkp``
    checkpoint into ``model_dir``, the same directory ``gen_tflite_graph``
    reads them from.

    Args:
        x, keep_prob, y_: placeholders from ``model_input``.
        train_step, loss, accuracy, merged_summary_op: ops from ``build_model``.
        saver: ``tf.train.Saver`` used to write the checkpoint.
        model_dir: directory for the graph definition and checkpoint.
        eval_batch_size: chunk size for the final test-set evaluation.
    """
    import os

    print('training Started.....')
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        # Write the graph next to the checkpoint (previously it went to a
        # cwd-relative "model/" while the checkpoint used an absolute path).
        tf.train.write_graph(sess.graph_def, model_dir, 'mnist_cnn.pbtxt', True)
        for step in range(NUM_STEP):
            batch_images, batch_labels = mnist.train.next_batch(BATCH_SIZE)
            # Reshape once per step; reused by both the eval and train feeds.
            batch_images = np.asarray(batch_images.reshape(-1, 28, 28, 1), dtype=np.float32)
            if step % 100 == 0:
                train_accuracy = accuracy.eval(
                    feed_dict={x: batch_images, y_: batch_labels, keep_prob: 1.0})
                print('step ',step,' training accuracy: ',train_accuracy)
            # NOTE(review): the merged summary is computed but never written;
            # attach a tf.summary.FileWriter if TensorBoard logs are wanted.
            _, _summary = sess.run(
                [train_step, merged_summary_op],
                feed_dict={x: batch_images, y_: batch_labels, keep_prob: 0.5})

        saver.save(sess, os.path.join(model_dir, 'mnist_cnn.chkp'))
        test_accuracy = _batched_accuracy(
            accuracy, x, y_, keep_prob,
            mnist.test.images, mnist.test.labels, eval_batch_size)
    print(' test accuracy: ',test_accuracy)
    print('training finished!')
    
def gen_tflite_graph(input_node_names,output_node_name, model_dir='/home/hari304/MNIST/model'):
    """Freeze the trained graph, optimise it for inference, and convert it to TFLite.

    Reads ``mnist_cnn.pbtxt`` and the ``mnist_cnn.chkp`` checkpoint from
    ``model_dir``. Writes the frozen graph to ``mnist_cnn.pb``, then overwrites
    that file with the inference-optimised graph. Finally writes the TFLite
    flatbuffer to ``mnist_cnn.tflite``.

    Args:
        input_node_names: list of placeholder names; only the first one (the
            image input) is used as the model input.
        output_node_name: name of the output op to keep when freezing.
        model_dir: directory holding the training artefacts and receiving the
            exported files.
    """
    import os

    graph_txt_path = os.path.join(model_dir, 'mnist_cnn.pbtxt')
    checkpoint_path = os.path.join(model_dir, 'mnist_cnn.chkp')
    frozen_pb_path = os.path.join(model_dir, 'mnist_cnn.pb')
    tflite_path = os.path.join(model_dir, 'mnist_cnn.tflite')
    input_name = input_node_names[0]

    # "save/restore_all" and "save/Const:0" are the restore op and filename
    # tensor of the default-named tf.train.Saver used during training.
    freeze_graph.freeze_graph(graph_txt_path, None, False,
        checkpoint_path, output_node_name, "save/restore_all",
        "save/Const:0", frozen_pb_path, True, "")

    input_graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_pb_path, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            input_graph_def, [input_name], [output_node_name],
            tf.float32.as_datatype_enum)

    with tf.gfile.GFile(frozen_pb_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    # Key the input shape by the actual input node name, not a literal 'input'.
    converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
        frozen_pb_path, [input_name], [output_node_name],
        input_shapes={input_name: [1, 28, 28, 1]})
    tflite_model = converter.convert()
    # Context manager closes the file; the bare open().write() leaked the handle.
    with open(tflite_path, 'wb') as f:
        f.write(tflite_model)
    print('TFLITE graph saved!')
    

In [6]:
# Graph node names used to address the model when it is frozen and exported.
input_node_name = 'input'
keep_prob_node_name = 'keep_prob'
output_node_name = 'output'

# Build the graph: placeholders first, then the network and its training ops.
x, keep_prob, y_ = model_input(
    input_node_name=input_node_name,
    keep_prob_node_name=keep_prob_node_name)
train_step, loss, accuracy, merged_summary_op = build_model(
    x=x, keep_prob=keep_prob, y_=y_, output_node_name=output_node_name)

# Created after the model so it covers the model's variables; its default
# "save/..." op names are what gen_tflite_graph passes to freeze_graph.
saver = tf.train.Saver()

train(x=x, keep_prob=keep_prob, y_=y_,
      train_step=train_step, loss=loss, accuracy=accuracy,
      merged_summary_op=merged_summary_op, saver=saver)
gen_tflite_graph(
    input_node_names=[input_node_name, keep_prob_node_name],
    output_node_name=output_node_name)

training Started.....
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
step  0  training accuracy:  0.1
step  100  training accuracy:  0.92
step  200  training accuracy:  0.89
step  300  training accuracy:  0.92
step  400  training accuracy:  0.97
step  500  training accuracy:  0.95
step  600  training accuracy:  0.97
step  700  training accuracy:  0.97
step  800  training accuracy:  0.96
step  900  training accuracy:  0.98
step  1000  training accuracy:  0.98
step  1100  training accuracy:  0.98
step  1200  training accuracy:  0.99
step  1300  training accuracy:  0.97
step  1400  training accuracy:  0.97
step  1500  training accuracy:  0.98
step  1600  training accuracy:  0.97
step  1700  training accuracy:  1.0
step  1800  training accuracy:  0.98
step  1900  training accuracy:  0.99
 test accuracy:  0.9866
training finished!
INFO:tensorflow