In [12]:
import os 
import tensorflow as tf

from typing import Dict, List, Tuple

In [2]:
# TFRecord file(s) holding the image pairs read by the input pipeline.
# Set the RAIN_TFRECORD environment variable to point at a local copy instead
# of editing the machine-specific default path below.
fnames_dataset = [
    os.environ.get(
        "RAIN_TFRECORD",
        "/Users/jonathanj/workspace/rainy-image-dataset/rain.tfrecord",
    )
]

In [81]:
def _parse(example_proto: tf.Tensor, crop_size: int = 300) -> Tuple[tf.Tensor, tf.Tensor]:
    """Decode one serialized example into an (input image, target image) pair.

    Args:
        example_proto: scalar string tensor holding one serialized
            ``tf.train.Example`` (as yielded by ``tf.data.TFRecordDataset``).
        crop_size: side length of the square crop taken from the top-left
            corner of each image. The default of 300 matches the
            ``input_shape`` used by the model in this notebook.

    Returns:
        ``(in_image, out_image)``: two float32 tensors of static shape
        ``[crop_size, crop_size, 3]``. ``decode_image`` converts the pixels
        to float32 (rescaled to [0, 1]).

    Note:
        ``crop_to_bounding_box`` fails at run time (not at graph-build time)
        if a decoded image is smaller than ``crop_size`` in either dimension.
    """
    def _fixed(t, default):
        return tf.FixedLenFeature((), t, default_value=default)

    keys_to_features = {
        'image/in/filename': _fixed(tf.string, ''),
        'image/out/filename': _fixed(tf.string, ''),
        'image/height': _fixed(tf.int64, 0),
        'image/width': _fixed(tf.int64, 0),
        'image/in/contents': _fixed(tf.string, ''),
        'image/out/contents': _fixed(tf.string, ''),
    }
    parsed = tf.parse_single_example(example_proto, keys_to_features)

    def _decode(img):
        # channels=3 forces RGB output (drops alpha, expands grayscale) so
        # every record matches the model's 3-channel input.
        _img = tf.image.decode_image(img, channels=3, dtype=tf.dtypes.float32)
        # decode_image cannot infer a static rank (it may also decode GIFs),
        # so pin it to a 3-D RGB image; Keras Conv2D needs a known channel
        # dimension to build its kernel.
        _img.set_shape([None, None, 3])
        _img = tf.image.crop_to_bounding_box(_img, 0, 0, crop_size, crop_size)
        return _img

    return _decode(parsed['image/in/contents']), _decode(parsed['image/out/contents'])


In [82]:
def input_fn(
    fs: List[str],
    batch_size: int = 1,
    shuffle_buffer: int = 1000,
) -> tf.data.Dataset:
    """Build the batched (input image, target image) pipeline.

    Args:
        fs: TFRecord file path(s) to read.
        batch_size: number of examples per batch (default 1).
        shuffle_buffer: number of records kept in the shuffle buffer
            (default 1000).

    Returns:
        A ``tf.data.Dataset`` of batched pairs produced by ``_parse``. The
        dataset is not repeated, so it ends after a single pass over ``fs``.
    """
    dataset = tf.data.TFRecordDataset(fs)
    # Shuffle the compact serialized records *before* decoding. Buffering
    # decoded float32 image pairs instead costs far more memory for the same
    # buffer size (each 300x300x3 float32 image alone is ~1 MB).
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    dataset = dataset.map(_parse)
    dataset = dataset.batch(batch_size)
    # Prepare the next batch while the current training step runs.
    dataset = dataset.prefetch(1)
    return dataset

In [90]:
def model_fn(
    features,
    labels,
    mode: tf.estimator.ModeKeys,
    params,
) -> tf.estimator.EstimatorSpec:
    """Estimator ``model_fn`` for a small fully convolutional image-to-image net.

    Args:
        features: batch of input images (first element of each pair from
            ``input_fn``). Presumably ``[batch, 300, 300, 3]`` float32;
            TODO confirm.
        labels: batch of target images, expected to match ``features``'s shape.
        mode: estimator mode; only ``tf.estimator.ModeKeys.TRAIN`` is supported.
        params: dict; ``params["learn_rate"]`` is the gradient-descent
            learning rate.

    Returns:
        A TRAIN ``EstimatorSpec`` carrying a scalar ``loss`` and a ``train_op``.

    Raises:
        tf.errors.UnimplementedError: for any mode other than TRAIN.
    """
    if mode != tf.estimator.ModeKeys.TRAIN:
        # OpError subclasses need (node_def, op, message); raising the bare
        # class would surface as a TypeError instead.
        raise tf.errors.UnimplementedError(
            None, None, "model_fn only supports mode=TRAIN, got %r" % (mode,))

    inputs = features
    expected_outputs = labels

    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(
            3,
            (16, 16),
            input_shape=(300, 300, 3),
            use_bias=True,
            activation=tf.nn.tanh,
            padding='same',
        ),
        tf.keras.layers.Conv2D(
            512,
            (1, 1),
            use_bias=True,
            activation=tf.nn.tanh,
        ),
        tf.keras.layers.Conv2D(
            3,
            (8, 8),
            use_bias=True,
            padding='same',
        ),
    ])

    pred = model(inputs, training=True)
    # Frobenius norm of the residual over its last two axes; for a 4-D
    # [batch, H, W, C] tensor that is one value per (image, row).
    norm = tf.norm(expected_outputs - pred, ord="fro", axis=[-2, -1])
    # Average over *all* remaining axes: Estimator (and its loss summary)
    # need a scalar loss. Reducing only axis 1 left a [batch] vector.
    loss = tf.reduce_mean(norm)
    optimizer = tf.train.GradientDescentOptimizer(params["learn_rate"])

    train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())

    return tf.estimator.EstimatorSpec(
        mode,
        loss=loss,
        train_op=train_op,
    )

# Playground: train the estimator and inspect one input batch

In [91]:
# Estimator driving model_fn. Checkpoints are written under model_dir; the
# RunConfig is left at its default.
regressor = tf.estimator.Estimator(
    model_fn,
    model_dir='/tmp/derain-ckpt',
    params=dict(learn_rate=0.01),
)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/derain-ckpt', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f6624374f28>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


In [None]:
# With neither steps nor max_steps set, training runs until the input
# dataset raises end-of-input.
regressor.train(
    input_fn=lambda: input_fn(fs=['/home/jjin/proj/rainy-image-dataset/rain.tfrecord']),
    steps=None,
)

INFO:tensorflow:Calling model_fn.
Instructions for updating:
Use tf.cast instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/derain-ckpt/model.ckpt.


In [64]:
# Build the pipeline directly to inspect one decoded (input, target) batch.
# Pass a list, matching input_fn's List[str] annotation and the train cell.
dataset = input_fn(["/home/jjin/proj/rainy-image-dataset/rain.tfrecord"])
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

In [66]:
# Fetch one element from the one-shot iterator.
# NOTE(review): this session is never closed; call sess.close() when finished.
sess = tf.Session()
elem = sess.run(fetches=next_element)