Commit
Update model.
nex3z committed Mar 25, 2018
1 parent 365a690 commit 12f2585
Showing 5 changed files with 61 additions and 68 deletions.
Empty file removed: __init__.py
6 changes: 3 additions & 3 deletions convert.py
@@ -2,7 +2,7 @@
 
 import tensorflow as tf
 from tensorflow.python.framework import graph_util
-from model import mnist as model
+import train
 
 # A workaround to fix an import issue
 # see https://github.com/tensorflow/tensorflow/issues/15410#issuecomment-352189481
@@ -20,9 +20,9 @@ def main():
 
 def convert(options):
     # Create a model to classify single image
-    x_single = tf.placeholder(tf.float32, [1, model.IMAGE_SIZE, model.IMAGE_SIZE, model.IMAGE_CHANNEL_NUM],
+    x_single = tf.placeholder(tf.float32, [1, train.IMAGE_SIZE, train.IMAGE_SIZE, train.IMAGE_CHANNEL_NUM],
                               name="input_single")
-    y_single = model.inference(x_single)
+    y_single = train.inference(x_single)
     output_single = tf.identity(tf.nn.softmax(y_single, axis=1), name="output_single")
 
     with tf.Session() as sess:
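Note: the diff view is truncated here, so the body of the session block is not shown. As a hedged sketch (the option names "checkpoint" and "output" below are assumptions, not taken from the commit), freezing in this style typically restores the trained variables and folds them into a self-contained GraphDef via the graph_util import shown above:

    # Sketch only -- the actual session body is elided in this diff.
    saver = tf.train.Saver()
    saver.restore(sess, options.checkpoint)  # assumed option name
    # Fold variables into constants so the graph needs no checkpoint.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ["output_single"])
    with tf.gfile.GFile(options.output, "wb") as f:  # assumed option name
        f.write(output_graph_def.SerializeToString())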
Empty file removed: model/__init__.py
63 changes: 0 additions & 63 deletions model/mnist.py

This file was deleted.

60 changes: 58 additions & 2 deletions train.py
@@ -6,7 +6,17 @@
 from tensorflow.examples.tutorials.mnist import input_data
 from tensorflow.python.framework import graph_util
 
-from model import mnist as model
+
+IMAGE_SIZE = 28
+IMAGE_CHANNEL_NUM = 1
+CONV_1_SIZE = 6
+CONV_1_DEPTH = 6
+CONV_2_SIZE = 5
+CONV_2_DEPTH = 12
+CONV_3_SIZE = 4
+CONV_3_DEPTH = 24
+FC_SIZE = 200
+OUTPUT_SIZE = 10
 
 LEARNING_RATE_MAX = 0.003
 LEARNING_RATE_MIN = 0.0001
@@ -29,11 +39,57 @@ def main():
     train(mnist_data, options)
 
 
+def get_weight(shape):
+    return tf.get_variable("weight", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))
+
+
+def get_bias(shape):
+    return tf.get_variable("bias", shape, initializer=tf.constant_initializer(0.0))
+
+
+def conv2d(input_tensor, weight, stride):
+    return tf.nn.conv2d(input_tensor, weight, strides=[1, stride, stride, 1], padding="SAME")
+
+
+def inference(input_tensor, regularizer=None):
+    with tf.variable_scope("layer_1_conv"):
+        conv_1_weight = get_weight([CONV_1_SIZE, CONV_1_SIZE, IMAGE_CHANNEL_NUM, CONV_1_DEPTH])
+        conv_1_bias = get_bias([CONV_1_DEPTH])
+        conv_1 = conv2d(input_tensor, conv_1_weight, stride=1)
+        conv_1_activation = tf.nn.relu(tf.nn.bias_add(conv_1, conv_1_bias))
+    with tf.variable_scope("layer_2_conv"):
+        conv_2_weight = get_weight([CONV_2_SIZE, CONV_2_SIZE, CONV_1_DEPTH, CONV_2_DEPTH])
+        conv_2_bias = get_bias([CONV_2_DEPTH])
+        conv_2 = conv2d(conv_1_activation, conv_2_weight, stride=2)
+        conv_2_activation = tf.nn.relu(tf.nn.bias_add(conv_2, conv_2_bias))
+    with tf.variable_scope("layer_3_conv"):
+        conv_3_weight = get_weight([CONV_3_SIZE, CONV_3_SIZE, CONV_2_DEPTH, CONV_3_DEPTH])
+        conv_3_bias = get_bias([CONV_3_DEPTH])
+        conv_3 = conv2d(conv_2_activation, conv_3_weight, stride=2)
+        conv_3_activation = tf.nn.relu(tf.nn.bias_add(conv_3, conv_3_bias))
+    shape = conv_3_activation.get_shape().as_list()
+    nodes = shape[1] * shape[2] * shape[3]
+    conv_3_activation_reshaped = tf.reshape(conv_3_activation, [-1, nodes])
+    with tf.variable_scope("layer_4_fc"):
+        w4 = get_weight([nodes, FC_SIZE])
+        if regularizer is not None:
+            tf.add_to_collection("losses", regularizer(w4))
+        b4 = get_bias([FC_SIZE])
+        a4 = tf.nn.relu(tf.matmul(conv_3_activation_reshaped, w4) + b4)
+    with tf.variable_scope("layer_5_fc"):
+        w5 = get_weight([FC_SIZE, OUTPUT_SIZE])
+        if regularizer is not None:
+            tf.add_to_collection("losses", regularizer(w5))
+        b5 = get_bias([OUTPUT_SIZE])
+        logits = tf.matmul(a4, w5) + b5
+    return logits
+
+
 def train(mnist_data, options):
     x = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x')
     y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
     regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
-    logits = model.inference(x, regularizer)
+    logits = inference(x, regularizer)
 
     cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=tf.argmax(y_, 1))
     loss = tf.reduce_mean(cross_entropy) + tf.add_n(tf.get_collection("losses"))
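As a quick sanity check on the new inference() above (not part of the commit): with SAME padding, a convolution of stride s maps a spatial side n to ceil(n / s), so the three convolutions take 28 -> 28 -> 14 -> 7, and the flattened size computed at runtime by nodes = shape[1] * shape[2] * shape[3] works out to 7 * 7 * CONV_3_DEPTH = 1176:

    import math

    # Trace the spatial side through the three SAME-padded convolutions
    # in inference(): stride 1, then stride 2, then stride 2.
    size = 28                       # IMAGE_SIZE
    for stride in (1, 2, 2):
        size = math.ceil(size / stride)
    print(size)                     # 7
    print(size * size * 24)         # 1176, matches the runtime "nodes"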
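The rest of train() is also truncated in this view, so the schedule that uses LEARNING_RATE_MAX and LEARNING_RATE_MIN is not visible. A common pattern for this pair of bounds (an assumption here, not confirmed by the diff) is a per-step exponential interpolation:

    import math

    DECAY_SPEED = 2000.0  # hypothetical constant; the real value is elided
    def learning_rate(step):
        # Decays from LEARNING_RATE_MAX toward LEARNING_RATE_MIN.
        return LEARNING_RATE_MIN + (LEARNING_RATE_MAX - LEARNING_RATE_MIN) * math.exp(-step / DECAY_SPEED)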

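For completeness, a usage sketch for the frozen graph produced by convert.py (the file name "frozen.pb" is an assumption): import the serialized GraphDef and feed the input_single / output_single tensors defined in the diff above.

    import numpy as np
    import tensorflow as tf

    # Hypothetical usage; "frozen.pb" is an assumed output file name.
    with tf.gfile.GFile("frozen.pb", "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        x = graph.get_tensor_by_name("input_single:0")
        y = graph.get_tensor_by_name("output_single:0")

    with tf.Session(graph=graph) as sess:
        image = np.zeros([1, 28, 28, 1], dtype=np.float32)  # dummy input
        probs = sess.run(y, feed_dict={x: image})
        print(probs)  # softmax over the 10 digit classes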