A (WIP) TensorFlow reproduction of Fu, Huang, Ding, Liao, and Paisley's method for single-image rain removal (https://arxiv.org/abs/1609.02087)

Reimplementing a deep network architecture for single-image rain removal

Note: This project is a work-in-progress. As such, there are bound to be some rough edges and/or missing functionality here and there. For now, I encourage readers and followers to take what’s currently here at face value and follow along for updates as the project progresses. Thanks for the interest! It means a lot to me.

Abstract

We present a re-implementation of the DerainNet method for single-image rain removal, written in the literate programming style. Our approach deviates from the authors’ implementation in several ways, most notably in its use of TensorFlow for every stage of the workflow: data preparation, model definition, model training, and so on.

For our purposes, we will use a mirror of the original authors’ dataset hosted at github.com/jinnovation/rainy-image-dataset.

[1/2] Development

We’ll split the development process roughly into the following stages:

  • Data importing/preparation;
  • Model training.

Data importing and preparation

For starters, we’ll need to implement the gateway into our training and evaluation data. We’ll do so using the tf.data.Dataset API; see TensorFlow > Develop > Programmer’s Guide > Importing Data for more details.

These contents will comprise our rainy_image_input module, which the training code imports later on.

import glob
import os
import tensorflow as tf

The dataset directory is organized into two folders:

  • ground truth images; and
  • rainy images.

GROUND_TRUTH_DIR = "ground truth"
RAINY_IMAGE_DIR = "rainy image"

These correspond to our model’s expected outputs and its inputs, respectively. Each “ground truth” image corresponds to one or more “rainy” images that share its “index” value: ground truth/1.jpg pairs with rainy image/1_{1,2,3}.jpg, ground truth/100.jpg with rainy image/100_{1,2,3,4,5}.jpg, and so on.
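To make the naming scheme concrete, here is a minimal sketch of how the shared index can be recovered from either kind of filename. The helper name index_of is hypothetical and is not part of the tangled module; it only assumes the `<index>.jpg` and `<index>_<variant>.jpg` naming described above.

```python
import os


def index_of(path):
    """Return the index shared by a ground-truth image and its rainy variants.

    Hypothetical helper for illustration; not part of the project's code.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    # Ground-truth stems are just the index ("100"); rainy stems append a
    # variant number after an underscore ("100_3").
    return stem.split("_")[0]


print(index_of("ground truth/100.jpg"))   # 100
print(index_of("rainy image/100_3.jpg"))  # 100
```

Note that the index comes back as a string, which is also what the directory-scanning function below produces.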

With this topology in mind, we can define a function to retrieve all index values within the data directory:

def _get_indices(data_dir):
    """Get indices for input/output association.

    Args:
    data_dir: Path to the data directory.

    Returns:
    indices: List of index values, as strings taken from the ground-truth
    filenames.

    Raises:
    ValueError: If no data_dir or no ground-truth dir.
    """

    if not tf.gfile.Exists(os.path.join(data_dir, GROUND_TRUTH_DIR)):
        raise ValueError("Failed to find ground-truth directory.")

    return [
        os.path.splitext(os.path.basename(f))[0]
        for f in glob.glob(os.path.join(data_dir, GROUND_TRUTH_DIR, "*.jpg"))
    ]

Likewise, we can define functions to retrieve the filenames of all input and output files corresponding to a given set of indices.

def _get_input_files(data_dir, indices):
    """Get input files from indices.

    Args:
    data_dir: Path to the data directory.
    indices: List of numerical index values.

    Returns:
    Dictionary, keyed by index value, valued by string lists containing
    one or more filenames.

    Raises:
    ValueError: If no rainy-image dir.
    """

    directory = os.path.join(data_dir, RAINY_IMAGE_DIR)
    if not tf.gfile.Exists(directory):
        raise ValueError("Failed to find rainy-image directory.")

    return {
        i: glob.glob(os.path.join(directory, "{}_[0-9]*.jpg".format(i)))
        for i in indices
    }


def _get_output_files(data_dir, indices):
    """Get output files from indices.

    Args:
    data_dir: Path to the data directory.
    indices: List of numerical index values.

    Returns:
    outputs: Dictionary, keyed by index value, valued by the filename of the
    corresponding ground-truth image.

    Raises:
    ValueError: If no ground-truth dir.
    """

    directory = os.path.join(data_dir, GROUND_TRUTH_DIR)
    if not tf.gfile.Exists(directory):
        raise ValueError("Failed to find ground-truth directory.")

    return {
        i: os.path.join(directory, "{}.jpg".format(i)) for i in indices
    }
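The rainy-image lookup hinges on the shell-style pattern "{}_[0-9]*.jpg". As a quick check of what that pattern accepts, the sketch below applies it with the standard-library fnmatch module, which follows the same wildcard rules as glob. The filenames and the helper rainy_names_for are made up for illustration.

```python
import fnmatch

# Made-up directory listing, for illustration only.
names = ["1_1.jpg", "1_2.jpg", "1_10.jpg", "11_1.jpg", "100_1.jpg", "1.jpg"]


def rainy_names_for(index, names):
    """Filter names with the same pattern used by _get_input_files."""
    pattern = "{}_[0-9]*.jpg".format(index)
    return [n for n in names if fnmatch.fnmatchcase(n, pattern)]


print(rainy_names_for(1, names))   # ['1_1.jpg', '1_2.jpg', '1_10.jpg']
print(rainy_names_for(11, names))  # ['11_1.jpg']
```

The underscore that follows the index is what keeps index 1 from also picking up the files that belong to indices 11 or 100.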

Now we arrive at our first deviation from the authors’ implementation. The authors, in their paper, don’t seem to perform any standardized cropping on their images, relying instead on the inputs all being identical in resolution or, at most, swapped in orientation, e.g. 384x512 and 512x384. Here, to ease integration with TensorFlow, we settle on a common side length and crop (or pad) every input and output to that length along both dimensions.

IMAGE_SIZE = 384
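As a rough illustration of what this common length means for a given image, the sketch below works out, for one dimension, how many pixels would be cropped or padded. It assumes a centered crop and an even zero-pad, which is how tf.image.resize_image_with_crop_or_pad is documented to behave. The helper crop_or_pad_1d is hypothetical and only does the arithmetic.

```python
IMAGE_SIZE = 384  # same value as in the write-up


def crop_or_pad_1d(size, target=IMAGE_SIZE):
    """Return (crop_offset, pad_before, pad_after) along one dimension.

    Hypothetical helper: assumes a centered crop or an even zero-pad.
    """
    diff = target - size
    if diff < 0:
        # Too large: drop (-diff) pixels, split as evenly as possible.
        return (-diff) // 2, 0, 0
    # Too small (or exact): add diff pixels of padding around the content.
    pad_before = diff // 2
    return 0, pad_before, diff - pad_before


# A 384x512 (height x width) image keeps its full height and loses 64
# columns from each side.
print(crop_or_pad_1d(384))  # (0, 0, 0)
print(crop_or_pad_1d(512))  # (64, 0, 0)
```

Under these assumptions, both 384x512 and 512x384 images end up as the central 384x384 region, so no padding occurs for the resolutions mentioned above.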

Finally, we have our canonical dataset-creation interface, concluding the first part of our project.

def dataset(data_dir, indices=None):
    """Construct dataset for rainy-image evaluation.

    Args:
    data_dir: Path to the data directory.
    indices: The input-output pairings to return. If None (the default), uses
    indices present in the data directory.

    Returns:
    dataset: Dataset of input-output images.
    """

    if not indices:
        indices = _get_indices(data_dir)

    fs_in = _get_input_files(data_dir, indices)
    fs_out = _get_output_files(data_dir, indices)

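    # Flatten into two parallel lists: every rainy image, in index order...
    ins = [
        fname for k, v in sorted(fs_in.items())
        for fname in v
    ]

    # ...and the ground-truth image it pairs with, repeated once per rainy
    # variant so that ins[i] and outs[i] always belong together.
    outs = [
        fname for k, fname in sorted(fs_out.items())
        for _ in fs_in[k]
    ]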

    def _parse_function(fname_in, fname_out):
        def _decode_resize(fname):
            f = tf.read_file(fname)
            contents = tf.image.decode_jpeg(f)
            resized = tf.image.resize_image_with_crop_or_pad(
                contents, IMAGE_SIZE, IMAGE_SIZE,
            )
            casted = tf.cast(resized, tf.float32)
            return casted

        return _decode_resize(fname_in), _decode_resize(fname_out)

    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.constant(ins), tf.constant(outs)),
    ).map(_parse_function)

    return dataset
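The least obvious step in dataset is the flattening of the two filename dictionaries into aligned lists. The following standalone sketch reproduces that step on a tiny, made-up listing, with no TensorFlow involved, so the pairing can be inspected directly. The dictionary contents are hypothetical.

```python
# Hypothetical file listings, shaped like the dictionaries returned by
# _get_input_files and _get_output_files.
fs_in = {
    "1": ["rainy image/1_1.jpg", "rainy image/1_2.jpg"],
    "2": ["rainy image/2_1.jpg"],
}
fs_out = {"1": "ground truth/1.jpg", "2": "ground truth/2.jpg"}

# One entry per rainy image, in index order.
ins = [fname for _, fnames in sorted(fs_in.items()) for fname in fnames]

# The matching ground-truth file, repeated once per rainy variant so that
# ins[i] and outs[i] always belong together.
outs = [fname for k, fname in sorted(fs_out.items()) for _ in fs_in[k]]

for pair in zip(ins, outs):
    print(pair)
```

Each printed pair is one (input, expected output) example; these two lists are what get handed to tf.data.Dataset.from_tensor_slices.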

Model Training

Dependencies

import tensorflow as tf
import logging
from rainy_image_input import dataset, IMAGE_SIZE

Execution Environment

tf.app.flags.DEFINE_string("checkpoint_dir", "/tmp/derain-checkpoint",
                           """Directory to write event logs and checkpointing
                           to.""")

tf.app.flags.DEFINE_string("data_dir",
                           "/tmp/derain_data",
                           """Path to the derain data directory.""")

tf.app.flags.DEFINE_integer("batch_size",
                            128,
                            """Number of images to process in a batch.""")

tf.app.flags.DEFINE_integer("max_steps",
                            1000000,
                            """Number of training batches to run.""")

LEVEL = tf.logging.DEBUG
FLAGS = tf.app.flags.FLAGS

LOG = logging.getLogger("derain-train")

Defining the model

MODEL_DEFAULT_PARAMS = {
    "learn_rate": 0.01,
}


def model_fn(features, labels, mode, params):
    inputs = features
    expected_outputs = labels
    tf.summary.image("inputs", inputs)
    tf.summary.image("expected_outputs", expected_outputs)
    global_step = tf.train.get_or_create_global_step()

    params = {**MODEL_DEFAULT_PARAMS, **params}

    l = tf.keras.layers
    model = tf.keras.Sequential([
        l.Conv2D(
            3,
            (16, 16),
            input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
            use_bias=True,
            activation=tf.nn.tanh,
            padding="same",
        )
        # # 512 kernels of 16x16x3
        # l.Conv2D(
        #     512,
        #     (16, 16),
        #     input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
        #     use_bias=True,
        #     activation=tf.nn.tanh,
        #     padding="same",
        # ),
        # # 512 kernels of 1x1x512
        # l.Conv2D(
        #     512,
        #     (1, 1),
        #     use_bias=True,
        #     activation=tf.nn.tanh,
        # ),
        # # 3 kernels of 8x8x512 (one for each color channel)
        # l.Conv2D(
        #     3,
        #     (8, 8),
        #     use_bias=True,
        #     padding="same",
        # ),
    ])

    # TODO: handle each of ModeKeys.{EVAL,TRAIN,PREDICT}
    if mode == tf.estimator.ModeKeys.TRAIN:
        predictions = model(inputs, training=True)

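        # Frobenius norm of the residual over the last two axes, averaged
        # over all remaining axes so the loss is a scalar for any batch size.
        norm = tf.norm(expected_outputs - predictions, ord="fro", axis=[-2, -1])
        loss = tf.reduce_mean(norm)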
        # tf.summary.scalar("loss", loss)

        optimizer = tf.train.GradientDescentOptimizer(params["learn_rate"])

        train_op = optimizer.minimize(loss, global_step=global_step)

        return tf.estimator.EstimatorSpec(
            mode,
            loss=loss,
            train_op=train_op,
        )

    raise NotImplementedError
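For a sense of scale, the sketch below counts the trainable parameters of the three-layer architecture listed in the commented-out portion of the model, next to the single placeholder layer that is currently active. The helper conv2d_params is hypothetical; it applies the usual count for a 2-D convolution (kernel height × kernel width × input channels × filters, plus one bias per filter) and assumes biases are enabled, as use_bias=True indicates.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters, use_bias=True):
    """Number of trainable parameters in one 2-D convolution layer."""
    weights = kernel_h * kernel_w * in_channels * filters
    biases = filters if use_bias else 0
    return weights + biases


# Layer shapes as listed in the commented-out part of the model definition.
full = [
    conv2d_params(16, 16, 3, 512),  # 512 kernels of 16x16x3
    conv2d_params(1, 1, 512, 512),  # 512 kernels of 1x1x512
    conv2d_params(8, 8, 512, 3),    # 3 kernels of 8x8x512
]
placeholder = conv2d_params(16, 16, 3, 3)  # the single layer currently active

print(full, sum(full))  # [393728, 262656, 98307] 754691
print(placeholder)      # 2307
```

The placeholder is therefore several hundred times smaller than the full network, which fits its role as a quick stand-in while the training pipeline is being wired up.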

The Rest of the Owl

def train():
    regressor = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=FLAGS.checkpoint_dir,
        # TODO
        config=None,
        params={},
    )

    regressor.train(
        input_fn=lambda: dataset(FLAGS.data_dir, range(1, 20)).batch(1),
        max_steps=FLAGS.max_steps,
    )


def main(argv=None):
    if tf.gfile.Exists(FLAGS.checkpoint_dir):
        LOG.debug("Emptying checkpoint dir")
        tf.gfile.DeleteRecursively(FLAGS.checkpoint_dir)

    LOG.debug("Creating checkpoint dir")
    tf.gfile.MakeDirs(FLAGS.checkpoint_dir)

    train()


if __name__ == "__main__":
    tf.logging.set_verbosity(LEVEL)
    tf.app.run(main)

References

  • http://smartdsp.xmu.edu.cn/derainNet.html
  • X. Fu, J. Huang, D. Zeng, Y. Huang, X. Ding and J. Paisley. “Removing Rain from Single Images via a Deep Detail Network”, CVPR, 2017.
  • X. Fu, J. Huang, X. Ding, Y. Liao and J. Paisley. “Clearing the Skies: A deep network architecture for single-image rain removal”, IEEE Transactions on Image Processing, vol. 26, no. 6, pp. 2944-2956, 2017.

Appendix

The Dataset

The authors’ rainy image dataset was originally distributed through the authors’ own page. Unfortunately, that page was unreliable at best when this project began, and as of 04/28/2018 it is entirely unavailable. As such, I took the liberty of cloning the authors’ dataset at the start of the project; the mirror is available at github.com/jinnovation/rainy-image-dataset.

The Report

This article is written with Emacs and Org mode in the literate programming style.

The Site

The source for this write-up, as well as all tangled source files, is hosted at jinnovation/derain-net. This write-up is generated using Org-mode’s HTML export functionality, as well as the ReadTheOrg theme. All resulting source files were tangled directly from the write-up document.