In [1]:
import numpy as np
import tensorflow as tf

# Side length (pixels) of the square patches decoded from the TFRecords.
PATCH_SIZE = 7
# Pixels per patch; derived so it cannot drift out of sync with PATCH_SIZE.
PATCH_PIXELS = PATCH_SIZE * PATCH_SIZE

# Examples per batch produced by the input pipeline.
BATCH_SIZE = 1000

# Number of enqueueing threads used by the input pipeline.
NUM_THREADS = 4

def read(filenames, batch_size=BATCH_SIZE, num_threads=NUM_THREADS, num_epochs=None):
    """Build a queue-based input pipeline yielding batches of (image, label) patches.

    Each TFRecord example holds two raw-byte features: 'mr_patch' (used as the
    image) and 'us_patch' (used as the label). Both are decoded as float64,
    reshaped to [PATCH_SIZE, PATCH_SIZE, 1] and cast to float32.
    NOTE(review): assumes the records were written as float64 buffers of exactly
    PATCH_SIZE * PATCH_SIZE values -- confirm against the record writer.

    Args:
        filenames: list of TFRecord file paths.
        batch_size: examples per batch (defaults to BATCH_SIZE).
        num_threads: number of enqueueing threads (defaults to NUM_THREADS).
        num_epochs: forwarded to tf.train.string_input_producer. None (the
            default) cycles through the files forever. An int makes the pipeline
            raise tf.errors.OutOfRangeError once exhausted, and requires
            tf.local_variables_initializer() to be run before starting the queue
            runners. A trailing partial batch is discarded (tf.train.batch's
            default behavior).

    Returns:
        The list [image_batch, label_batch] returned by tf.train.batch, each a
        float32 tensor of shape [batch_size, PATCH_SIZE, PATCH_SIZE, 1].
    """
    queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs)

    with tf.name_scope('reader'):
        reader = tf.TFRecordReader()
        _, example = reader.read(queue)

    with tf.name_scope('features'):
        features = tf.parse_single_example(example, features={
            'us_patch': tf.FixedLenFeature([], tf.string),
            'mr_patch': tf.FixedLenFeature([], tf.string),
        })
        with tf.name_scope('decode'):
            image_raw = tf.decode_raw(features['mr_patch'], tf.float64)
            label_raw = tf.decode_raw(features['us_patch'], tf.float64)
            image = tf.cast(tf.reshape(image_raw, [PATCH_SIZE, PATCH_SIZE, 1]), tf.float32)
            label = tf.cast(tf.reshape(label_raw, [PATCH_SIZE, PATCH_SIZE, 1]), tf.float32)

    return tf.train.batch([image, label],
        batch_size=batch_size,
        num_threads=num_threads,
        # Queue headroom: a fixed 1000 examples plus three batches.
        capacity=1000 + 3 * batch_size)

In [2]:
test_files = ['th-30/05.tfrecord']
train_files = ['th-30/13.tfrecord']

def count_records(filenames):
    """Return the number of records in each TFRecord file, in input order.

    Reads through every record of every file, so this is as slow as a full pass
    over the data.
    """
    # Builtin sum: np.sum over a generator is deprecated in NumPy and only
    # falls back to the builtin anyway.
    return [sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in filenames]

test_count, train_count = count_records(test_files), count_records(train_files)
test_count, train_count

([48974], [74709])

In [3]:
def summarize(name, variable):
    """Attach TensorBoard summaries describing `variable` and return it unchanged.

    Inside a name scope called `name` this records scalar summaries for the
    mean, the population standard deviation (divides by the element count),
    the minimum and the maximum, followed by a histogram of all values.

    Args:
        name: name scope that groups the summaries.
        variable: tensor or variable to describe.

    Returns:
        `variable` itself, so the call can wrap tensor creation inline.
    """
    with tf.name_scope(name):
        with tf.name_scope('mean'):
            average = tf.reduce_mean(variable)
        tf.summary.scalar('mean', average)

        with tf.name_scope('stddev'):
            squared_deviation = tf.square(variable - average)
            spread = tf.sqrt(tf.reduce_mean(squared_deviation))
        tf.summary.scalar('stddev', spread)

        for tag, reduce_fn in (('min', tf.reduce_min), ('max', tf.reduce_max)):
            tf.summary.scalar(tag, reduce_fn(variable))
        tf.summary.histogram('histogram', variable)

    return variable

In [4]:
with tf.name_scope('input'):
    # Both placeholders take a variable-size batch of single-channel square patches.
    patch_batch_shape = [None, PATCH_SIZE, PATCH_SIZE, 1]
    image = tf.placeholder(dtype=tf.float32, shape=patch_batch_shape, name='image')
    label = tf.placeholder(dtype=tf.float32, shape=patch_batch_shape, name='label')

def create_weight(name, shape, stddev=1.0):
    """Create a trainable weight variable with TensorBoard summaries attached.

    Values are drawn from tf.truncated_normal with mean 0 and the given
    `stddev`. The default of 1.0 is tf.truncated_normal's own default, so
    existing callers get exactly the previous initialization.
    NOTE(review): a stddev of 1.0 is large for conv filters; consider passing
    a smaller value.
    """
    initial = tf.truncated_normal(shape=shape, stddev=stddev)
    return summarize(name, tf.Variable(initial, name=name))

def create_bias(name, shape):
    """Create a bias variable filled with 0.1, with TensorBoard summaries attached."""
    initial = tf.constant(.1, shape=shape)
    bias = tf.Variable(initial, name=name)
    return summarize(name, bias)
   
def create_conv(placeholder, weight):
    """2-D convolution with stride 1 and 'SAME' padding (spatial size preserved)."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(placeholder, weight, strides=unit_strides, padding='SAME')
    
with tf.name_scope('model'):
    # Layer 1: 3x3 convolution from 1 to 3 channels plus a per-channel bias
    # (no activation function is applied).
    conv1_weight = create_weight('conv1_weight', [3, 3, 1, 3])
    conv1_bias = create_bias('conv1_bias', [3])
    conv1 = create_conv(image, conv1_weight) + conv1_bias

    # Layer 2: 1x1 convolution mixing the 3 channels back down to 1 (no bias).
    conv2_weight = create_weight('conv2_weight', [1, 1, 3, 1])
    conv2 = create_conv(conv1, conv2_weight)

    tf.summary.histogram('conv1', conv1)
    tf.summary.histogram('conv2', conv2)

    with tf.name_scope('cost'):
        # tf.nn.l2_loss is sum(residual ** 2) / 2 over the whole batch.
        residual = label - conv2
        cost = tf.nn.l2_loss(residual)
        tf.summary.scalar('cost', cost)

    with tf.name_scope('train'):
        optimizer = tf.train.AdamOptimizer(.0001)
        train = optimizer.minimize(cost)

In [None]:
import matplotlib.cm as cm
import matplotlib.pyplot as plt

def imshow(image):
    """Display one image in grayscale with nearest-pixel (unsmoothed) rendering.

    Accepts a 2-D array or a 3-D array with a single trailing channel, e.g. one
    [PATCH_SIZE, PATCH_SIZE, 1] patch taken from a batch produced by `read`.
    The singleton channel is dropped because older matplotlib versions reject
    (M, N, 1) arrays in imshow.
    """
    pixels = np.asarray(image)
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    plt.imshow(pixels, cmap=cm.gray, interpolation='none')
    plt.show()

In [None]:
# Train the model on patches streamed from `train_files`.
# NOTE: re-running this cell adds another input pipeline to the default graph.
TRACE_EVERY = 100   # FULL_TRACE is expensive; record run metadata only this often
NUM_EPOCHS = None   # set to an int to stop after that many (approximate) epochs

batch = read(train_files)

# The file queue cycles forever, so an "epoch" is approximated as the number
# of full batches that fit in the training set.
steps_per_epoch = max(1, int(np.sum(train_count)) // BATCH_SIZE)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Build the summary op and writer before starting threads, so a failure
    # here cannot leave queue-runner threads running unjoined.
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter('/tmp/mrtous', sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        global_step = 0  # monotonic across epochs; used as the summary x-axis
        step = 0         # step within the current epoch
        epoch = 0

        while not coord.should_stop():
            # Dequeue image and label in ONE run call so they come from the same
            # batch. Two separate .eval() calls would each dequeue a different batch.
            image_batch, label_batch = sess.run(batch)
            feed = {image: image_batch, label: label_batch}

            if global_step % TRACE_EVERY == 0:
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                _, summary = sess.run([train, merged], feed_dict=feed,
                                      options=run_options, run_metadata=run_metadata)
                writer.add_run_metadata(run_metadata, 'epoch{}step{}'.format(epoch, step))
            else:
                _, summary = sess.run([train, merged], feed_dict=feed)

            print('epoch: {}, step: {}'.format(epoch, step))
            writer.add_summary(summary, global_step)

            global_step += 1
            step += 1
            if step >= steps_per_epoch:
                step = 0
                epoch += 1
                if NUM_EPOCHS is not None and epoch >= NUM_EPOCHS:
                    coord.request_stop()

    except tf.errors.OutOfRangeError:
        print('done with training')
    finally:
        writer.close()

        coord.request_stop()
        coord.join(threads)

epoch: 0, step: 0
epoch: 0, step: 1
epoch: 0, step: 2
epoch: 0, step: 3
epoch: 0, step: 4
epoch: 0, step: 5
epoch: 0, step: 6
epoch: 0, step: 7
epoch: 0, step: 8
epoch: 0, step: 9
