# In [None]:
#coding=utf-8
import  tensorflow as tf
from tensorflow.python.ops import control_flow_ops
import math
import numpy as np

def cifar10_input_stream(records_path):
    """Build the graph ops that read one example from a TFRecord file.

    Each record is parsed as a pair of features: ``image_raw`` (a byte
    string decoded as uint8) and ``label`` (an int64 scalar).

    Args:
      records_path: path of the TFRecord file to read.

    Returns:
      A pair ``(image, label)``: ``image`` is a float32 tensor reshaped to
      [32, 32, 3]; ``label`` is an int64 scalar tensor.
    """
    record_reader = tf.TFRecordReader()
    # num_epochs=None: the producer is not limited to a fixed number of passes.
    path_queue = tf.train.string_input_producer([records_path], num_epochs=None)
    _key, serialized = record_reader.read(path_queue)

    feature_spec = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized, feature_spec)

    # Raw bytes -> uint8 vector -> [32, 32, 3] -> float32.
    pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    pixels = tf.reshape(pixels, [32, 32, 3])
    image = tf.cast(pixels, tf.float32)
    label = tf.cast(parsed['label'], tf.int64)
    return image, label

def normalize_image(image):
    """Standardise an image per channel using fixed statistics.

    Args:
      image: array/tensor whose last axis has 3 entries (it is broadcast
        against the 3-element mean and std lists below).

    Returns:
      ``(image - mean) / std`` computed element-wise per channel.
    """
    channel_mean = [125.30690002, 122.95014954, 113.86599731]
    channel_std = [62.9932518, 62.08860397, 66.70500946]
    centered = image - channel_mean
    return centered / channel_std

def random_distort_image(image):
    """Apply pad-and-random-crop plus random horizontal flip to one image.

    Args:
      image: a single image tensor; the crop size below fixes the output
        to [32, 32, 3].

    Returns:
      The augmented image tensor of shape [32, 32, 3].
    """
    # Place the image at offset (4, 4) inside a 40x40 canvas, i.e. a
    # 4-pixel border on every side of a 32x32 input.
    padded = tf.image.pad_to_bounding_box(image, 4, 4, 40, 40)
    cropped = tf.random_crop(padded, [32, 32, 3])
    return tf.image.random_flip_left_right(cropped)

def make_train_batch(train_records_path, batch_size):
    """Build the shuffled, augmented training batch tensors.

    Examples are read from the TFRecord file, normalised, randomly
    distorted (in that order) and then grouped by ``tf.train.shuffle_batch``.

    Args:
      train_records_path: path of the training TFRecord file.
      batch_size: number of examples per batch.

    Returns:
      A pair ``(image_batch, label_batch)`` of batched tensors.
    """
    image, label = cifar10_input_stream(train_records_path)
    augmented = random_distort_image(normalize_image(image))

    shuffle_options = dict(
        batch_size=batch_size,
        num_threads=4,
        capacity=50000,
        min_after_dequeue=1000,
    )
    image_batch, label_batch = tf.train.shuffle_batch(
        [augmented, label], **shuffle_options)
    return image_batch, label_batch

def make_validation_batch(test_records_path, batch_size):
    """Build the validation batch tensors (normalised, no augmentation).

    Uses ``tf.train.batch`` with a single thread rather than the shuffling
    batcher used for training.

    Args:
      test_records_path: path of the validation/test TFRecord file.
      batch_size: number of examples per batch.

    Returns:
      A pair ``(image_batch, label_batch)`` of batched tensors.
    """
    image, label = cifar10_input_stream(test_records_path)
    normalized = normalize_image(image)

    batch_options = dict(batch_size=batch_size, num_threads=1, capacity=10000)
    image_batch, label_batch = tf.train.batch(
        [normalized, label], **batch_options)
    return image_batch, label_batch


def one_hot_embedding(label, n_classes):
    """Turn integer class indices into one-hot float32 rows.

    Implemented as a row lookup into an ``n_classes x n_classes`` identity
    matrix, so row ``k`` of the result has a single 1.0 in column ``label[k]``.

    Args:
      label: integer tensor of class indices, e.g. shape [B].
      n_classes: int, number of classes.

    Returns:
      float32 tensor with one extra trailing axis of size ``n_classes``
      (shape [B, n_classes] for a [B] input).
    """
    identity = tf.constant(np.eye(n_classes, dtype=np.float32))
    return tf.gather(identity, label)



def weight_variable(shape):
    """Create a variable of ``shape`` initialised from a truncated normal.

    Args:
      shape: list of ints giving the variable's shape.

    Returns:
      A ``tf.Variable`` whose initial value is drawn with stddev 0.01.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.01))


def bias_variable(shape):
    """Create a variable of ``shape`` with every entry initialised to 0.01.

    Args:
      shape: list of ints giving the variable's shape.

    Returns:
      A ``tf.Variable`` filled with the constant 0.01.
    """
    return tf.Variable(tf.constant(0.01, shape=shape))


def conv2d(input, in_features, out_features, kernel_size, stride):
    """Square 2-D convolution with a freshly created kernel and no bias.

    Args:
      input: input tensor passed to ``tf.nn.conv2d``.
      in_features: number of input channels.
      out_features: number of output channels.
      kernel_size: side length of the square kernel.
      stride: stride used for both spatial axes.

    Returns:
      The convolution output, computed with 'SAME' padding.
    """
    kernel_shape = [kernel_size, kernel_size, in_features, out_features]
    kernel = weight_variable(kernel_shape)
    strides = [1, stride, stride, 1]
    return tf.nn.conv2d(input, kernel, strides, padding='SAME')


def basic_block(input, in_features, out_features, stride, phase_train):
    """Residual block: conv-BN-ReLU-conv-BN plus a parameter-free shortcut.

    The shortcut is the input itself, average-pooled when ``stride != 1``
    and zero-padded along the channel axis when the block widens the
    feature maps. Pooling and padding are decided independently, so a
    stride-1 block may also increase the channel count, and an odd channel
    difference is padded fully (the extra channel goes at the back).

    Args:
      input: input tensor with ``in_features`` channels on its last axis.
      in_features: number of input channels.
      out_features: number of output channels (must be >= ``in_features``).
      stride: spatial stride of the first convolution and of the shortcut
        pooling.
      phase_train: bool (or bool tensor) passed to batch norm as
        ``is_training``.

    Returns:
      The sum of the residual branch and the shortcut. No ReLU is applied
      after the addition.

    Raises:
      ValueError: if ``out_features < in_features``; zero-padding cannot
        reduce the number of channels.
    """
    extra_channels = out_features - in_features
    if extra_channels < 0:
        raise ValueError(
            'basic_block cannot reduce channels: in_features=%d, out_features=%d'
            % (in_features, out_features))

    shortcut = input
    if stride != 1:
        # NOTE: 'VALID' pooling matches the 'SAME' strided conv below only
        # when the spatial size is divisible by the stride.
        window = [1, stride, stride, 1]
        shortcut = tf.nn.avg_pool(shortcut, window, window, 'VALID')
    if extra_channels > 0:
        pad_front = extra_channels // 2
        pad_back = extra_channels - pad_front  # absorbs the remainder when odd
        shortcut = tf.pad(
            shortcut, [[0, 0], [0, 0], [0, 0], [pad_front, pad_back]])

    current = conv2d(input, in_features, out_features, 3, stride)
    current = tf.contrib.layers.batch_norm(current, scale=True, is_training=phase_train, updates_collections=None)
    current = tf.nn.relu(current)
    current = conv2d(current, out_features, out_features, 3, 1)
    current = tf.contrib.layers.batch_norm(current, scale=True, is_training=phase_train, updates_collections=None)
    # No final relu as per http://torch.ch/blog/2016/02/04/resnets.html
    return current + shortcut


def block_stack(input, in_features, out_features, stride, depth, phase_train):
    """Chain ``depth`` residual blocks into one stage.

    The first block maps ``in_features`` to ``out_features`` using
    ``stride``; every following block keeps ``out_features`` channels and
    uses stride 1. The first block is always built, even if ``depth`` is
    less than 1.

    Args:
      input: input tensor of the stage.
      in_features: channels entering the stage.
      out_features: channels produced by every block of the stage.
      stride: stride of the first block.
      depth: number of blocks in the stage.
      phase_train: forwarded to every block's batch norm.

    Returns:
      Output tensor of the last block.
    """
    # (input channels, stride) for each block, in order.
    block_configs = [(in_features, stride)] + [(out_features, 1)] * (depth - 1)
    output = input
    for block_in, block_stride in block_configs:
        output = basic_block(output, block_in, out_features, block_stride, phase_train)
    return output

# ---------------------------------------------------------------------------
# Graph construction: input pipelines, ResNet body, loss and metrics.
# ---------------------------------------------------------------------------
# Validation set bookkeeping; val_batch_size is also the batch size of the
# validation input pipeline so the two can never drift apart.
n_val_samples = 10000
val_batch_size = 100
n_val_batch = n_val_samples // val_batch_size

phase_train = tf.placeholder(tf.bool, name='phase_train')
learning_rate = tf.placeholder(tf.float32, name='learning_rate')
train_image_batch, train_label_batch = make_train_batch(
    'F:/raochuan_code/my_ResNet_v2/data/data/train.tf', batch_size=128)
val_image_batch, val_label_batch = make_validation_batch(
    'F:/raochuan_code/my_ResNet_v2/data/data/test.tf', batch_size=val_batch_size)
# phase_train selects which pipeline feeds the network.
# NOTE(review): both batch tensors are created outside the cond branch
# functions. Per the tf.cond documentation only ops defined inside the
# branches are conditional, so presumably both queues are dequeued on every
# sess.run regardless of phase_train -- confirm before relying on each
# validation pass covering every test example exactly once.
image_batch, label_batch = control_flow_ops.cond(
    phase_train,
    lambda: (train_image_batch, train_label_batch),
    lambda: (val_image_batch, val_label_batch))
targets = one_hot_embedding(label_batch, 10)

# Stem convolution: 3 input channels -> 16 feature maps, 3x3 kernel, stride 1.
current = conv2d(image_batch, 3, 16, 3, 1)
current = tf.nn.relu(current)

# dimension is 32x32x16
current = block_stack(current, 16, 16, 1, 6, phase_train)
current = block_stack(current, 16, 32, 2, 6, phase_train)
# dimension is 16x16x32
current = block_stack(current, 32, 64, 2, 6, phase_train)
# dimension is 8x8x64

# Global average pooling over the two spatial axes, then a 10-way classifier.
current = tf.reduce_mean(current, reduction_indices=[1, 2], name="avg_pool")
final_dim = 64
current = tf.reshape(current, [-1, final_dim])
Wfc = weight_variable([final_dim, 10])
bfc = bias_variable([10])
ys_ = tf.nn.softmax(tf.matmul(current, Wfc) + bfc)

# NOTE(review): reduce_mean averages over every element of the [batch, 10]
# product, so this is the usual per-example cross-entropy divided by the
# number of classes (10). Left unchanged because rescaling the loss would
# change the effective learning rate; consider
# tf.nn.softmax_cross_entropy_with_logits on the pre-softmax values instead
# of the 1e-12 epsilon.
cross_entropy = -tf.reduce_mean(targets * tf.log(ys_ + 1e-12))

train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(ys_, 1), tf.argmax(targets, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# ---------------------------------------------------------------------------
# Session setup and training loop.
# ---------------------------------------------------------------------------
sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
# A Coordinator lets the input threads be stopped cleanly and surfaces
# exceptions raised inside them instead of leaving the main loop blocked.
coord = tf.train.Coordinator()
queue_threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    for epoch in range(1, 81):
        # Step decay: 0.1 for epochs 1-29, 0.01 for 30-59, 0.001 from 60 on.
        # (Previously the lower rates applied only during epochs 30 and 60
        # themselves because lr was reset to 0.1 at the top of every epoch.)
        if epoch >= 60:
            lr = 0.001
        elif epoch >= 30:
            lr = 0.01
        else:
            lr = 0.1

        # 500 optimisation steps per reporting period.
        for turn in range(500):
            batch_res = sess.run([train_step, accuracy], feed_dict={phase_train: True, learning_rate: lr})

        # Evaluate on n_val_batch validation batches. val_logits holds the
        # softmax outputs (ys_), not pre-softmax logits.
        val_logits = np.zeros((n_val_samples, 10), dtype=np.float32)
        val_labels = np.zeros((n_val_samples), dtype=np.int64)
        for i in range(n_val_batch):
            start = i * val_batch_size
            stop = start + val_batch_size
            session_outputs = sess.run([ys_, label_batch], feed_dict={phase_train: False})
            val_logits[start:stop, :] = session_outputs[0]
            val_labels[start:stop] = session_outputs[1]
        pred_labels = np.argmax(val_logits, axis=1)
        # float() keeps this a true division under Python 2 as well.
        val_accuracy = float(np.count_nonzero(pred_labels == val_labels)) / n_val_samples
        # batch_res[1] is the accuracy of the last training batch only.
        print('第%d 次测试，训练准确率：%f，测试准确率：%f' % (epoch, batch_res[1], val_accuracy))
finally:
    # Stop the input threads; the session is left open so the model can
    # still be saved afterwards.
    coord.request_stop()
    coord.join(queue_threads)
# To persist the trained weights: saver.save(sess, './model/model.ckpt')