Update data reader for compatibility with TensorFlow 2.0 #16

Open · wants to merge 9 commits into `master`
18 changes: 7 additions & 11 deletions README.md

````diff
@@ -52,21 +52,17 @@ The following versions of the datasets are available:
 textures.
 
 ### Usage example
 
-To select what dataset to load, instantiate a reader passing the correct
-`version` argument. Note that the constructor will set up all the queues used by
-the reader. To get tensors call `read` on the data reader passing in the desired
-batch size.
-
 ```python
 import tensorflow as tf
-from data_reader import DataReader
+from data_reader import data_reader
 
 root_path = 'path/to/datasets/root/folder'
-data_reader = DataReader(dataset='jaco', context_size=5, root=root_path)
-data = data_reader.read(batch_size=12)
+dataset = data_reader(dataset='jaco',
+                      root=root_path,
+                      batch_size=12,
+                      context_size=5)
 
-with tf.train.SingularMonitoredSession() as sess:
-  d = sess.run(data)
+# Train a Keras model on the dataset
+model.fit(dataset)
 ```
 
 ### Download
````
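Note that `data_reader` as defined in this PR takes a required `mode` argument (see `data_reader.py` below), which the README snippet above omits, so the call there would need something like `mode='train'` to run. As a quick sanity check, the returned `tf.data.Dataset` can also be iterated eagerly under TF 2.x. A minimal sketch, assuming the datasets have been downloaded under `root_path`; the printed shapes are indicative values for the `jaco` dataset:

```python
import tensorflow as tf
from data_reader import data_reader

root_path = 'path/to/datasets/root/folder'  # placeholder path, as in the README
dataset = data_reader(dataset='jaco', root=root_path, mode='train',
                      batch_size=12, context_size=5)

# The pipeline repeats indefinitely, so bound the iteration with take().
for task in dataset.take(1):
  print(task.inputs.context_frames.shape)  # e.g. (12, 5, 64, 64, 3)
  print(task.inputs.query_camera.shape)    # e.g. (12, 7)
  print(task.target.shape)                 # e.g. (12, 64, 64, 3)
```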
292 changes: 117 additions & 175 deletions data_reader.py
```diff
@@ -14,22 +14,20 @@

"""Minimal data reader for GQN TFRecord datasets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import tensorflow as tf
nest = tf.contrib.framework.nest

AUTOTUNE = tf.data.experimental.AUTOTUNE

DatasetInfo = collections.namedtuple(
'DatasetInfo',
['basepath', 'train_size', 'test_size', 'frame_size', 'sequence_size']
)
Context = collections.namedtuple('Context', ['frames', 'cameras'])
Query = collections.namedtuple('Query', ['context', 'query_camera'])
TaskData = collections.namedtuple('TaskData', ['query', 'target'])
Query = collections.namedtuple('Query', ['query_camera', 'target'])
Input = collections.namedtuple('Input', ['context_frames', 'context_cameras', 'query_camera', 'target'])
TaskData = collections.namedtuple('TaskData', ['inputs', 'target'])


_DATASETS = dict(
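With the renamed tuples, every dataset element is a `TaskData` whose `inputs` field bundles the context views with the query camera. A small sketch of how downstream code can unpack an element, assuming `Input` and `TaskData` are imported from `data_reader`; the zero tensors are placeholders with illustrative shapes (batch 12, context 5, 64x64 RGB frames, 7 camera parameters):

```python
import tensorflow as tf
from data_reader import Input, TaskData

inputs = Input(
    context_frames=tf.zeros([12, 5, 64, 64, 3]),
    context_cameras=tf.zeros([12, 5, 7]),
    query_camera=tf.zeros([12, 7]),
    target=tf.zeros([12, 64, 64, 3]))
task = TaskData(inputs=inputs, target=inputs.target)

# Namedtuples unpack like plain tuples, so a TaskData element behaves as an
# (x, y) pair, which is the structure Model.fit expects from a tf.data input.
x, y = task
assert x.query_camera.shape == (12, 7)
```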
```diff
@@ -87,181 +85,125 @@
 _MODES = ('train', 'test')
 
 
-def _get_dataset_files(dateset_info, mode, root):
-  """Generates lists of files for a given dataset version."""
-  basepath = dateset_info.basepath
-  base = os.path.join(root, basepath, mode)
-  if mode == 'train':
-    num_files = dateset_info.train_size
-  else:
-    num_files = dateset_info.test_size
-
-  length = len(str(num_files))
-  template = '{:0%d}-of-{:0%d}.tfrecord' % (length, length)
-  return [os.path.join(base, template.format(i + 1, num_files))
-          for i in range(num_files)]
-
-
-def _convert_frame_data(jpeg_data):
-  decoded_frames = tf.image.decode_jpeg(jpeg_data)
-  return tf.image.convert_image_dtype(decoded_frames, dtype=tf.float32)
-
-
-class DataReader(object):
-  """Minimal queue based TFRecord reader.
-
-  You can use this reader to load the datasets used to train Generative Query
-  Networks (GQNs) in the 'Neural Scene Representation and Rendering' paper.
-  See README.md for a description of the datasets and an example of how to use
-  the reader.
-  """
-
-  def __init__(self,
-               dataset,
-               context_size,
-               root,
-               mode='train',
-               # Optionally reshape frames
-               custom_frame_size=None,
-               # Queue params
-               num_threads=4,
-               capacity=256,
-               min_after_dequeue=128,
-               seed=None):
-    """Instantiates a DataReader object and sets up queues for data reading.
-
-    Args:
-      dataset: string, one of ['jaco', 'mazes', 'rooms_ring_camera',
-          'rooms_free_camera_no_object_rotations',
-          'rooms_free_camera_with_object_rotations', 'shepard_metzler_5_parts',
-          'shepard_metzler_7_parts'].
-      context_size: integer, number of views to be used to assemble the context.
-      root: string, path to the root folder of the data.
-      mode: (optional) string, one of ['train', 'test'].
-      custom_frame_size: (optional) integer, required size of the returned
-          frames, defaults to None.
-      num_threads: (optional) integer, number of threads used to feed the reader
-          queues, defaults to 4.
-      capacity: (optional) integer, capacity of the underlying
-          RandomShuffleQueue, defualts to 256.
-      min_after_dequeue: (optional) integer, min_after_dequeue of the underlying
-          RandomShuffleQueue, defualts to 128.
-      seed: (optional) integer, seed for the random number generators used in
-          the reader.
-
-    Raises:
-      ValueError: if the required version does not exist; if the required mode
-          is not supported; if the requested context_size is bigger than the
-          maximum supported for the given dataset version.
-    """
-    if dataset not in _DATASETS:
-      raise ValueError('Unrecognized dataset {} requested. Available datasets '
-                       'are {}'.format(dataset, _DATASETS.keys()))
-
-    if mode not in _MODES:
-      raise ValueError('Unsupported mode {} requested. Supported modes '
-                       'are {}'.format(mode, _MODES))
-
-    self._dataset_info = _DATASETS[dataset]
-
-    if context_size >= self._dataset_info.sequence_size:
-      raise ValueError(
-          'Maximum support context size for dataset {} is {}, but '
-          'was {}.'.format(
-              dataset, self._dataset_info.sequence_size-1, context_size))
-
-    self._context_size = context_size
-    # Number of views in the context + target view
-    self._example_size = context_size + 1
-    self._custom_frame_size = custom_frame_size
-
-    with tf.device('/cpu'):
-      file_names = _get_dataset_files(self._dataset_info, mode, root)
-      filename_queue = tf.train.string_input_producer(file_names, seed=seed)
-      reader = tf.TFRecordReader()
-
-      read_ops = [self._make_read_op(reader, filename_queue)
-                  for _ in range(num_threads)]
-
-      dtypes = nest.map_structure(lambda x: x.dtype, read_ops[0])
-      shapes = nest.map_structure(lambda x: x.shape[1:], read_ops[0])
-
-      self._queue = tf.RandomShuffleQueue(
-          capacity=capacity,
-          min_after_dequeue=min_after_dequeue,
-          dtypes=dtypes,
-          shapes=shapes,
-          seed=seed)
-
-      enqueue_ops = [self._queue.enqueue_many(op) for op in read_ops]
-      tf.train.add_queue_runner(tf.train.QueueRunner(self._queue, enqueue_ops))
-
-  def read(self, batch_size):
-    """Reads batch_size (query, target) pairs."""
-    frames, cameras = self._queue.dequeue_many(batch_size)
-    context_frames = frames[:, :-1]
-    context_cameras = cameras[:, :-1]
-    target = frames[:, -1]
-    query_camera = cameras[:, -1]
-    context = Context(cameras=context_cameras, frames=context_frames)
-    query = Query(context=context, query_camera=query_camera)
-    return TaskData(query=query, target=target)
-
-  def _make_read_op(self, reader, filename_queue):
-    """Instantiates the ops used to read and parse the data into tensors."""
-    _, raw_data = reader.read_up_to(filename_queue, num_records=16)
-    feature_map = {
-        'frames': tf.FixedLenFeature(
-            shape=self._dataset_info.sequence_size, dtype=tf.string),
-        'cameras': tf.FixedLenFeature(
-            shape=[self._dataset_info.sequence_size * _NUM_RAW_CAMERA_PARAMS],
-            dtype=tf.float32)
-    }
-    example = tf.parse_example(raw_data, feature_map)
-    indices = self._get_randomized_indices()
-    frames = self._preprocess_frames(example, indices)
-    cameras = self._preprocess_cameras(example, indices)
-    return frames, cameras
-
-  def _get_randomized_indices(self):
-    """Generates randomized indices into a sequence of a specific length."""
-    indices = tf.range(0, self._dataset_info.sequence_size)
-    indices = tf.random_shuffle(indices)
-    indices = tf.slice(indices, begin=[0], size=[self._example_size])
-    return indices
-
-  def _preprocess_frames(self, example, indices):
-    """Instantiates the ops used to preprocess the frames data."""
-    frames = tf.concat(example['frames'], axis=0)
-    frames = tf.gather(frames, indices, axis=1)
-    frames = tf.map_fn(
-        _convert_frame_data, tf.reshape(frames, [-1]),
-        dtype=tf.float32, back_prop=False)
-    dataset_image_dimensions = tuple(
-        [self._dataset_info.frame_size] * 2 + [_NUM_CHANNELS])
-    frames = tf.reshape(
-        frames, (-1, self._example_size) + dataset_image_dimensions)
-    if (self._custom_frame_size and
-        self._custom_frame_size != self._dataset_info.frame_size):
-      frames = tf.reshape(frames, (-1,) + dataset_image_dimensions)
-      new_frame_dimensions = (self._custom_frame_size,) * 2 + (_NUM_CHANNELS,)
-      frames = tf.image.resize_bilinear(
-          frames, new_frame_dimensions[:2], align_corners=True)
-      frames = tf.reshape(
-          frames, (-1, self._example_size) + new_frame_dimensions)
-    return frames
-
-  def _preprocess_cameras(self, example, indices):
-    """Instantiates the ops used to preprocess the cameras data."""
-    raw_pose_params = example['cameras']
-    raw_pose_params = tf.reshape(
-        raw_pose_params,
-        [-1, self._dataset_info.sequence_size, _NUM_RAW_CAMERA_PARAMS])
-    raw_pose_params = tf.gather(raw_pose_params, indices, axis=1)
-    pos = raw_pose_params[:, :, 0:3]
-    yaw = raw_pose_params[:, :, 3:4]
-    pitch = raw_pose_params[:, :, 4:5]
-    cameras = tf.concat(
-        [pos, tf.sin(yaw), tf.cos(yaw), tf.sin(pitch), tf.cos(pitch)], axis=2)
-    return cameras
```
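Everything above is the queue-based reader that this PR deletes; the replacement `tf.data` pipeline is added below. One practical difference: the old `_get_dataset_files` built every shard filename from the shard count, while the new code matches shards with a single `*.tfrecord` glob via `tf.data.Dataset.list_files`. A short illustration of the old naming scheme, with an assumed shard count of 10:

```python
import os

num_files = 10  # assumed value for illustration
length = len(str(num_files))  # -> 2 digits
template = '{:0%d}-of-{:0%d}.tfrecord' % (length, length)
print(template.format(1, num_files))  # -> '01-of-10.tfrecord'

# The new pipeline needs only a glob, independent of the shard count:
file_pattern = os.path.join('root', 'jaco', 'train', '*.tfrecord')
```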
```diff
+def _convert_frame_data(jpeg_data):
+  decoded_frames = tf.image.decode_jpeg(jpeg_data)
+
+  return tf.image.convert_image_dtype(decoded_frames, dtype=tf.float32)
+
+
+def _get_randomized_indices(context_size, sequence_size):
+  """Generates randomized indices into a sequence of a specific length."""
+  if context_size is None:
+    maximum_context_size = min(sequence_size-1, 20)
+    context_size = tf.random.uniform([1], 1, maximum_context_size, dtype=tf.int32)
+  else:
+    context_size = tf.constant(context_size, shape=[1], dtype=tf.int32)
+  example_size = context_size + 1
+
+  indices = tf.range(0, sequence_size)
+  indices = tf.random.shuffle(indices)
+  indices = tf.slice(indices, begin=[0], size=example_size)
+
+  return indices, example_size
+
+
+def data_reader(dataset,
+                root,
+                mode,
+                batch_size,
+                context_size=None,
+                custom_frame_size=None,
+                shuffle_buffer_size=256,
+                num_parallel_reads=4,
+                seed=None):
+
+  if dataset not in _DATASETS:
+    raise ValueError('Unrecognized dataset {} requested. Available datasets '
+                     'are {}'.format(dataset, _DATASETS.keys()))
+
+  if mode not in _MODES:
+    raise ValueError('Unsupported mode {} requested. Supported modes '
+                     'are {}'.format(mode, _MODES))
+
+  dataset_info = _DATASETS[dataset]
+
+  if context_size is not None and context_size >= dataset_info.sequence_size:
+    raise ValueError(
+        'Maximum supported context size for dataset {} is {}, but '
+        'was {}.'.format(
+            dataset, dataset_info.sequence_size-1, context_size))
+
+  tf.random.set_seed(seed)
+
+  basepath = dataset_info.basepath
+  file_pattern = os.path.join(root, basepath, mode, '*.tfrecord')
+  files = tf.data.Dataset.list_files(file_pattern, shuffle=False)
+  raw_dataset = files.interleave(
+      tf.data.TFRecordDataset, cycle_length=num_parallel_reads,
+      num_parallel_calls=AUTOTUNE).repeat().shuffle(shuffle_buffer_size, seed=seed)
+
+  feature_map = {
+      'frames': tf.io.FixedLenFeature(
+          shape=dataset_info.sequence_size, dtype=tf.string),
+      'cameras': tf.io.FixedLenFeature(
+          shape=[dataset_info.sequence_size * _NUM_RAW_CAMERA_PARAMS],
+          dtype=tf.float32)
+  }
+
+  def _parse_function(example):
+    return tf.io.parse_single_example(example, feature_map)
+
+  parsed_dataset = raw_dataset.map(_parse_function,
+                                   num_parallel_calls=AUTOTUNE).batch(batch_size)
+
+  def _preprocess_fn(example):
+    frames = example['frames']
+    raw_pose_params = example['cameras']
+
+    indices, example_size = _get_randomized_indices(context_size, dataset_info.sequence_size)
+
+    frames = tf.gather(frames, indices, axis=1)
+    frames = tf.map_fn(
+        _convert_frame_data, tf.reshape(frames, [-1]),
+        dtype=tf.float32, back_prop=False)
+    dataset_image_dimensions = tuple(
+        [dataset_info.frame_size] * 2 + [_NUM_CHANNELS])
+    frames = tf.reshape(
+        frames, (-1, example_size[0]) + dataset_image_dimensions)
+    if (custom_frame_size and
+        custom_frame_size != dataset_info.frame_size):
+      frames = tf.reshape(frames, (-1,) + dataset_image_dimensions)
+      new_frame_dimensions = (custom_frame_size,) * 2 + (_NUM_CHANNELS,)
+      frames = tf.image.resize(
+          frames, new_frame_dimensions[:2])
+      frames = tf.reshape(
+          frames, (-1, example_size[0]) + new_frame_dimensions)
+
+    raw_pose_params = tf.reshape(
+        raw_pose_params,
+        [-1, dataset_info.sequence_size, _NUM_RAW_CAMERA_PARAMS])
+    raw_pose_params = tf.gather(raw_pose_params, indices, axis=1)
+    pos = raw_pose_params[:, :, 0:3]
+    yaw = raw_pose_params[:, :, 3:4]
+    pitch = raw_pose_params[:, :, 4:5]
+    cameras = tf.concat(
+        [pos, tf.sin(yaw), tf.cos(yaw), tf.sin(pitch), tf.cos(pitch)], axis=2)
+
+    context_frames = frames[:, :-1]
+    context_cameras = cameras[:, :-1]
+    target = frames[:, -1]
+    query_camera = cameras[:, -1]
+    # context = Context(frames=context_frames, cameras=context_cameras)
+    query = Query(query_camera=query_camera, target=target)
+    inputs = Input(context_frames=context_frames,
+                   context_cameras=context_cameras,
+                   query_camera=query_camera,
+                   target=target)
+
+    return TaskData(inputs=inputs, target=target)
+
+  preprocessed_dataset = parsed_dataset.map(_preprocess_fn,
+                                            num_parallel_calls=AUTOTUNE)
+
+  return preprocessed_dataset.prefetch(buffer_size=AUTOTUNE)
```
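Because the reader now returns a plain `tf.data.Dataset`, it plugs into standard TF 2.x training loops. A sketch of a custom training step over the new element structure; `my_gqn_model` stands in for a user-supplied model (not part of this PR) that maps an `Input` tuple to a predicted frame:

```python
import tensorflow as tf

optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(model, task):
  """Runs one gradient step; task.target is the held-out query frame."""
  with tf.GradientTape() as tape:
    prediction = model(task.inputs)
    loss = tf.reduce_mean(tf.square(prediction - task.target))
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss

# for task in dataset.take(1000):
#   loss = train_step(my_gqn_model, task)
```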