# Train a Simplified LeNet5 Model to classify mnist data

LeNet5 is a well know model which was introduced in paper "Gradient-based learning applied to document recognition" by LeCun in 1998. In the paper it is used to classify the MNIST handscripts numbers. Here we will train a simplified LeNet5 model to classify CIFAR-10 data.


## Section 1 Define Super Training Parameters

- epoch: define the iteration of the train
- batch_size: define the train batch size. It depends on how large the memory is. CIFAR-10 is a very small images. 200 - 500 should be good. 
- test_size: define the test batch size.  
- learn rate (lr): The start learn rate for Agagrad
- keep_prob: the probability of the training parameter
- augument: To have a better training effect, the image augument is always True

In [1]:
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1" 
sys.path.append('../common/')
sys.path.append('../cifar10/')

import tensorflow as tf
import numpy as np
from train_log import train_log
from tensorflow.examples.tutorials.mnist import input_data
from lenet5 import LeNet5

FLAGS = tf.flags.FLAGS

try:
    #Super parameter definition
    tf.flags.DEFINE_string('f', '', 'kernel')
    tf.flags.DEFINE_integer('epoch', 50000, 'epoch')
    tf.flags.DEFINE_integer('batch_size',250, 'batch size')
    tf.flags.DEFINE_integer('test_size', 250, 'test size')
    tf.flags.DEFINE_float('lr', 0.01, 'learning rate')
    tf.flags.DEFINE_float('drop_rate', 0.5, 'drop out rate for drop lay')
    #Other training parameter                        
    tf.flags.DEFINE_float('ckpt_frequency', 125, 'frequency to save checkpoint')
    tf.flags.DEFINE_boolean('restore', False, 'restore from checkpoint and run test')
    print('parameters were defined.')
except:
    print('parameters have been defined.')

CONTINUE = 40000
RUN = 1
print("RUN "+str(RUN)+": " + str(CONTINUE) + "-" + str(FLAGS.epoch))
print("batch size=",FLAGS.batch_size, "learn Rate =",FLAGS.lr, "drop rate=", FLAGS.drop_rate)

parameters were defined.
RUN 1: 40000-50000
batch size= 250 learn Rate = 0.01 drop rate= 0.5


## Section 2. Generate Checkpoint dir and Log dir

- Checkpoint dir is saved in variable **../Le-Net5-Log/Le-Net5_CLASS/ckpt_RUN**, if the dir doesn't exist then create it. 
- Log file dir is saved in variable **../Le-Net5-Log/Le-Net5_CLASS/log_RUN**, if the dir doesn't exist then create it. 
- data_path is the position of the CIFAR-10 image data
The reason to save the model and log outside the project is to avoid effect the git code management.

In [2]:
ckpt_dir = './mnist_study_log/ckpt_'+str(RUN)+'/'
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

log_dir = './mnist_study_log/log_'+str(RUN)+'/'
log = train_log(log_dir)
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

log.write_file('configuration',['epoch='+str(FLAGS.epoch),
                                'batch_size ='+str(FLAGS.batch_size), 
                                'lr ='+str(FLAGS.lr),
                                'drop_rate ='+str(FLAGS.drop_rate),
                                'ckpt_frequency = '+str(FLAGS.ckpt_frequency)])

data_path = 'mnist_data'
if not os.path.exists(data_path):
    print('The data path doesn\'t exist. Please check if it is a correct data path.')

mnist = input_data.read_data_sets(data_path, one_hot=True)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting mnist_data/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting mnist_data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting mnist_data/t10k-images-idx3-ubyte.gz
Extracting mnist_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


## Section 3. Simplified LeNet5 
This is a simplied LeNet5 model. A lot of details are ignored in the model.Below is the architecture of the LeNet5 we implemented. It is a little bit different from the original one as some detail design were removed.  

![image.png](attachment:image.png)


## Section 4. Build the calculation graph

### 4.1 Input layer 

Use the data feeder to provide the training data. Therefore we define a placeholder with the same structure of the input data. 
The CIFAR-10 data is 60000 RGB images with each size is 32x32. The data structure should be[batchsize, 32,32,3]. Input channels are 3.  

In [3]:
with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None,32,32,1], name='x_input')
    #x_image = tf.reshape(x, [-1, 28, 28, 1])
    drop_rate = tf.placeholder(tf.float32, name='drop_rate')
    y_ = tf.placeholder(tf.int64, [None,10], name='labels')
    label = tf.argmax(y_,1)

### 4.2 Deifne the LeNet5 object

Use the previously defined LeNet5 class to build LeNet5 network.
y is the output of the LeNet5 network. 

In [4]:
with tf.name_scope('prediction'):
    le_net5 = LeNet5(x, drop_rate)
    y = le_net5.prediction    

norm_0:  (?, 32, 32, 1)
conv_1:  (?, 28, 28, 6)
pool_1:  (?, 14, 14, 6)
conv_2:  (?, 10, 10, 16)
pool_2:  (?, 5, 5, 16)
conv_3: (?, 1, 1, 120)
flat_1: (?, 120)
fc_2  (?, 84)
fc_3:  (?, 10)
drop_out:  (?, 10)
prediction:  (?, 10)


### 4.3 Calculate the cross entropy as the loss
Use the cross entropy as the loss. 
**cross entropy** is normally used as the loss of the network. 
$$cross\_entropy(output, label) = \sum_i{output_i*log(label_i)}$$

tensorflow function **sparse_softmax_cross_entropy_with_logits** calculate the cross entropy with the logits which is one hot output from the network and the label which is a interger number.


In [5]:
with tf.name_scope('cross_entropy'):
    cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,
                                                                                  labels=label, 
                                                                                  name="cross_entropy_per_example"))
    #cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y)))

### 4.4 Use Adagrad to minimize the loss

In [6]:
with tf.name_scope('train_step'):
    train_step = tf.train.AdagradOptimizer(FLAGS.lr).minimize(cross_entropy)
    #train_step = tf.train.SGDOptimizer(FLAGS.lr).minimize(cross_entropy)

### 4.5 Calculate the reduce mean as the accuracy

In [7]:
with tf.name_scope('accuracy'):
    prediction =tf.argmax(y, 1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction,label), tf.float32))

## Now, let's start the training...

In [8]:
import time
import matplotlib.pyplot as plt

saver = tf.train.Saver(max_to_keep=1)
#GPU memory usage 40%
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.35)

with tf.Session(config=tf.ConfigProto(log_device_placement=True, gpu_options=gpu_options)) as sess:
    sess.run(tf.global_variables_initializer())
    tf.get_default_graph().finalize()
    if CONTINUE != 0:
        model_file=tf.train.latest_checkpoint(ckpt_dir)
        saver.restore(sess,model_file)
    for i in range(CONTINUE, FLAGS.epoch):
        train_image, train_label = mnist.train.next_batch(FLAGS.batch_size)
        train_image = np.array(train_image).reshape(FLAGS.batch_size,28,28,1)
        train_image = np.pad(train_image,((0,0),(2,2),(2,2),(0,0)),'constant',constant_values=(np.min(train_image),np.min(train_image)))        
        '''
        for im in train_image:
            im = im.reshape(32,32)
            plt.imshow(im)
            plt.show()
        '''
        pred_value, loss, _,accuracy_rate = sess.run([y, cross_entropy, train_step, accuracy], 
                                         feed_dict={drop_rate: FLAGS.drop_rate, x:train_image, y_:train_label})
        print('.',end='')
        #print(time.strftime("%Y-%m-%d %H:%M:%S",time.localtime())+' iter '+str(i)+',Train accuracy:'+str(round(accuracy_rate*100,2))+'%')
        #log.add_log('logits',i,pred_value)
        log.add_log('train_accuracy',i, accuracy_rate)
        log.add_log('train_loss',i, loss)
        if (i+1) % FLAGS.ckpt_frequency == 0:  #保存预测模型
            saver.save(sess,ckpt_dir+'cifar10_'+str(i+1)+'.ckpt',global_step=i+1)
            acc_accuracy = 0
            for j in range(int(10000/FLAGS.test_size)):                    
                test_image, test_label = mnist.test.next_batch(FLAGS.test_size)
                test_image = np.array(test_image).reshape(FLAGS.test_size,28,28,1)
                test_image = np.pad(test_image,((0,0),(2,2),(2,2),(0,0)),'constant',constant_values=(np.min(test_image),np.min(test_image)))        
                pred_value, loss, _,accuracy_rate = sess.run([y, cross_entropy, train_step, accuracy], 
                                         feed_dict={drop_rate: FLAGS.drop_rate, x:train_image, y_:train_label})
                accuracy_rate, output = sess.run([accuracy,prediction],
                                                 feed_dict={drop_rate: 0, x:test_image, y_:test_label})
                acc_accuracy += accuracy_rate
                #log.add_log('test_batch_accuracy',i, accuracy_rate)
                #log.add_log('test_index',i, test_index)
                #log.add_log('output',i, output)
            accuracy_rate = acc_accuracy/10000*FLAGS.test_size
            print()
            print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + 
                  ' iter ' + str(i) + ', Test accuracy:' +str(round(accuracy_rate*100,2))+'%')
            log.add_log('test_accuracy',i, accuracy_rate)
            log.SaveToFile()

INFO:tensorflow:Restoring parameters from ./mnist_study_log/ckpt_1/cifar10_40000.ckpt-40000
.............................................................................................................................
2020-03-28 14:11:04 iter 40124, Test accuracy:98.12%
.............................................................................................................................
2020-03-28 14:11:16 iter 40249, Test accuracy:98.21%
.............................................................................................................................
2020-03-28 14:11:29 iter 40374, Test accuracy:98.11%
.............................................................................................................................
2020-03-28 14:11:41 iter 40499, Test accuracy:97.98%
.............................................................................................................................
2020-03-28 14:11:53 iter 40624, Test accuracy:98.03%
.............