In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np
import os

# Raise TensorFlow's logging threshold to ERROR so lower-severity messages are not shown.
tf.logging.set_verbosity(tf.logging.ERROR)

# Build Graph

## Data 읽어오기

In [2]:
def get_data(data_dir="datasets/mnist", one_hot=True):
    """Load the MNIST dataset through the TensorFlow tutorial helper.

    Args:
        data_dir: Directory handed to ``input_data.read_data_sets``.
            Defaults to the notebook's original location, "datasets/mnist".
        one_hot: Forwarded as the ``one_hot`` flag of ``read_data_sets``.
            Defaults to True, matching the original behaviour.

    Returns:
        The object returned by ``input_data.read_data_sets``; this notebook
        reads ``.train.images`` and calls ``.train.next_batch`` on it.
    """
    return input_data.read_data_sets(data_dir, one_hot=one_hot)

## Model 정의

### Placeholder : graph 입력부분 정의

In [3]:
def get_inputs():
    """Create the two float32 placeholders that feed the graph.

    Returns:
        A pair ``(x, y)``: ``x`` has shape [None, 28*28] and receives the
        image data, ``y`` has shape [None, 10] and receives the label data.
    """
    # Image input: one row of 28*28 values per example.
    image_input = tf.placeholder(shape=[None, 28 * 28], dtype=tf.float32)
    # Label input: one row of 10 values per example.
    label_input = tf.placeholder(shape=[None, 10], dtype=tf.float32)
    return image_input, label_input

### Model : algorithm을 graph 연산으로 정의

In [4]:
def get_model(images):
    """Build a small convolutional classifier and return its logits.

    Graph: reshape to [-1, 28, 28, 1] -> 3x3 conv with 16 filters (SAME
    padding, stride 1) -> ReLU -> 2x2 max-pool with stride 2 -> flatten ->
    fully connected layer with 10 outputs.

    The shape of each intermediate tensor is printed while the graph is built.

    Args:
        images: Tensor that is reshaped here to [-1, 28, 28, 1].

    Returns:
        The logits tensor produced by the fully connected layer (10 columns);
        no softmax is applied.
    """
    x_image = tf.reshape(images, [-1, 28, 28, 1])

    # conv2d filter layout: [height, width, in_channels, out_channels]
    conv1_filters = tf.Variable(tf.random_normal([3, 3, 1, 16], stddev=0.01))
    conv1 = tf.nn.conv2d(x_image, conv1_filters, strides=[1, 1, 1, 1], padding='SAME')
    conv1 = tf.nn.relu(conv1)
    # Single-argument print(...) emits the same text on Python 2 and Python 3.
    print('conv1 {}'.format(conv1))

    pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    print('pool1 {}'.format(pool1))

    # Static height/width/depth of the pooled feature map; the leading
    # (batch) dimension is not needed because the reshape below uses -1.
    _, h, w, d = pool1.get_shape().as_list()
    flat_dim = h * w * d
    flatten = tf.reshape(pool1, [-1, flat_dim])
    print('flatten {}'.format(flatten))

    fc_weights = tf.Variable(tf.random_normal([flat_dim, 10], stddev=0.01))
    # NOTE(review): no stddev is passed here, so the bias uses random_normal's
    # default spread rather than the 0.01 used for the weights - confirm intended.
    fc_bias = tf.Variable(tf.random_normal([10]))

    logits = tf.matmul(flatten, fc_weights) + fc_bias
    print('logits {}'.format(logits))
    return logits

## Loss 정의

In [5]:
def get_loss(logits, labels):
    """Return the mean softmax cross-entropy between ``logits`` and ``labels``.

    Args:
        logits: Model outputs passed as the ``logits`` argument.
        labels: Targets passed as the ``labels`` argument.

    Returns:
        The ``tf.reduce_mean`` of the cross-entropy op's output.
    """
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=labels,
        logits=logits,
    )
    return tf.reduce_mean(cross_entropy)

## Optimizer 정의

In [6]:
def get_optimizer(lr):
    """Create a gradient-descent optimizer.

    Args:
        lr: Value passed as the optimizer's ``learning_rate``.

    Returns:
        A ``tf.train.GradientDescentOptimizer`` instance.
    """
    optimizer_cls = tf.train.GradientDescentOptimizer
    return optimizer_cls(learning_rate=lr)

## Train Graph 정의

In [7]:
# 1. Load the data
mnist = get_data()

# 2. Create the model inputs (placeholders) and the model
images, labels = get_inputs()
model_out = get_model(images)

# 3. Create the loss
loss = get_loss(model_out, labels)

# 4. Create the optimizer
optimizer = get_optimizer(lr=0.1)

# Op returned by the optimizer for minimizing `loss`; the training loop runs it once per batch.
train_op = optimizer.minimize(loss)

Extracting datasets/mnist/train-images-idx3-ubyte.gz
Extracting datasets/mnist/train-labels-idx1-ubyte.gz
Extracting datasets/mnist/t10k-images-idx3-ubyte.gz
Extracting datasets/mnist/t10k-labels-idx1-ubyte.gz
conv1 Tensor("Relu:0", shape=(?, 28, 28, 16), dtype=float32)
pool1 Tensor("MaxPool:0", shape=(?, 14, 14, 16), dtype=float32)
flatten Tensor("Reshape_1:0", shape=(?, 3136), dtype=float32)
logits Tensor("add:0", shape=(?, 10), dtype=float32)


## (Optional) Metric 정의

In [8]:
# Softmax over the model's logits.
prediction = tf.nn.softmax(model_out)
# Per-example comparison of the argmax (axis 1) of the prediction with that of the labels.
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(labels, 1))
# Mean of the matches cast to float32, i.e. the fraction of correct predictions in the fed batch.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Run Graph

In [9]:
# Training schedule.
EPOCHS = 3        # number of outer training-loop iterations
BATCH_SIZE = 100  # examples requested per training step
# Whole batches that fit in the training images; any remainder is dropped.
NUM_BATCH_PER_EPOCH = mnist.train.images.shape[0] // BATCH_SIZE

## Training

In [10]:
# Create the session and initialise every variable in the graph.
# The session is intentionally left open here - presumably later cells reuse
# `sess`; close it with sess.close() once it is no longer needed.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

for ep in range(EPOCHS):
    for st in range(NUM_BATCH_PER_EPOCH):
        batch_images, batch_labels = mnist.train.next_batch(BATCH_SIZE)
        # One optimisation step; accuracy and loss are fetched in the same run call.
        _, _acc, _loss = sess.run(
            [train_op, accuracy, loss],
            feed_dict={images: batch_images, labels: batch_labels})

        # Report progress every 100 steps. The single-argument print(...) form
        # produces the same text on Python 2 and Python 3.
        if st % 100 == 0:
            print('{} Epoch, {} Step : acc({:.4f}), loss({:.4f})'.format(ep, st, _acc, _loss))

0 Epoch, 0 Step : acc(0.0500), loss(2.5984)
0 Epoch, 100 Step : acc(0.8600), loss(0.5017)
0 Epoch, 200 Step : acc(0.8900), loss(0.3413)
0 Epoch, 300 Step : acc(0.9300), loss(0.2371)
0 Epoch, 400 Step : acc(0.9100), loss(0.2648)
0 Epoch, 500 Step : acc(0.8400), loss(0.5781)
1 Epoch, 0 Step : acc(0.9600), loss(0.1924)
1 Epoch, 100 Step : acc(0.9000), loss(0.2607)
1 Epoch, 200 Step : acc(0.8900), loss(0.3644)
1 Epoch, 300 Step : acc(0.9100), loss(0.2403)
1 Epoch, 400 Step : acc(0.9100), loss(0.3497)
1 Epoch, 500 Step : acc(0.8800), loss(0.4741)
2 Epoch, 0 Step : acc(0.8600), loss(0.4483)
2 Epoch, 100 Step : acc(0.9300), loss(0.2329)
2 Epoch, 200 Step : acc(0.9300), loss(0.2527)
2 Epoch, 300 Step : acc(0.9500), loss(0.2022)
2 Epoch, 400 Step : acc(0.8700), loss(0.3921)
2 Epoch, 500 Step : acc(0.9700), loss(0.2174)
