## loading tfrecords & Parsing 

In [2]:
import tensorflow as tf
import numpy as np 
import matplotlib.pyplot as plt 
%matplotlib inline

In [10]:
# Parser for TFRecord
def parser(serialized_example):
    features = {
        'weight' : tf.FixedLenFeature([1], tf.float32),
        'img' : tf.FixedLenFeature([250, 550, 3], tf.string)
    }
    parsed_feature = tf.parse_single_example(serialized_example, features)
    weight = parsed_feature['weight']
    img = parsed_feature['img']
    return weight, img

tf.reset_default_graph()

train =  "../sample_image/preprocessed/train.tfrecord"
test = "../sample_image/preprocessed/test.tfrecord"

# Dataset
train_dataset = tf.data.TFRecordDataset(train).map(parser)
train_dataset = train_dataset.batch(24).shuffle(777)
test_dataset = tf.data.TFRecordDataset(test).map(parser)
test_dataset = test_dataset.batch(24).shuffle(777)

itr = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
weight, img = itr.get_next()

img = tf.reshape(img, [-1, 250, 550, 3]) # 차원 주의 !
img = tf.cast(img, tf.float32)

weight = tf.reshape(weight, [-1]) # 차원 주의 ! 

train_init_op = itr.make_initializer(train_dataset)
test_init_op = itr.make_initializer(test_dataset)

(TensorShape([Dimension(None), Dimension(1)]),
 TensorShape([Dimension(None), Dimension(250), Dimension(550), Dimension(3)]))

## Modeling 

In [11]:
def model(x, activation, dropout_prob, reuse=False):
    # 10개의 feature map, 3x3 필터사이즈 
    conv1 = tf.layers.conv3d(img, filters=16, kernel_size=3, 
                             padding='SAME', activation=activation, 
                             reuse=reuse, name='conv1')
    # 2x2 maxpooling stride=2 
    pool1 = tf.layers.max_pooling3d(conv1, pool_size=2, strides=2)
    
    conv2 = tf.layers.conv2d(pool1, filters=16, kernel_size=3,
                            padding="SAME", activation=activation,
                            reuse=reuse, name='conv2')
    pool2 = tf.layers.max_pooling2d(conv2, pool_size=2, strides=2)
    
    conv3 = tf.layers.conv3d(pool2, filters=32, kernel_size=3, 
                            padding="SAME", activation=activation,
                            reuse=reuse, name='conv3')
    pool3 = tf.layers.max_pooling3d(conv3, pool_size=3, strides=2)
   
    flat = tf.layers.flatten(pool3)
    
    dropout1 = tf.layers.dropout(flat, rate=dropout_prob)
    fc1 = tf.layers.dense(dropout1, units=100, reuse=reuse, 
                          name='fc1')

    dropout2 = tf.layers.dropout(fc1, rate=dropout_prob)
    out = tf.layers.dense(dropout2, 1, reuse=reuse, name='out')
    
    return out 

train_out = model(img, tf.nn.relu, 0.7)
test_out = model(img, tf.nn.relu, 1, True)

loss = tf.losses.mean_squared_error(weight, train_out)
train_op = tf.train.AdagradOptimizer(1e-4).minimize(loss)

summary = tf.summary.scalar('train_loss', loss)

# test(validation) score
pred = test_out
mse = tf.metrics.mean_squared_error(weight, pred)
    
merged = tf.summary.merge_all()

saver = tf.train.Saver()

ValueError: Tensor("conv1/kernel:0", shape=(3, 3, 3, 16), dtype=float32_ref) must be from the same graph as Tensor("Cast:0", shape=(?, 250, 550, 3), dtype=float32).

In [None]:
with tf.Session() as sess :
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    
#     writer = tf.summary.FileWriter('../logs/', sess.graph)
    loss_graph = []
    mse_graph = []
    
    for epoch in range(100):
        sess.run(train_init_op)
        for train_step in range(9999999):
            try :
                _, _loss, _summ = sess.run([train_op, loss, summary])
                loss_graph.append(_loss)
                writer.add_summary(_summ, train_step)
                    
            except tf.errors.OutOfRangeError :
                break
                
                
        sess.run(test_init_op)
        for test_step in range(99999999):
                try:
                    _mse = sess.run(mse)
                    mse_graph.append(_mse)
                except tf.errors.OutOfRangeError:
                    break
        print('epochs : {}, train_loss : {}, val_mse : {}'.format(epoch, _loss, _mse[0]))
        saver.save(sess, "../logs/cnn{}".format(epoch))