# Keyframe Extraction
Extracts keyframes and goal images from all tfrecords in a GEECO dataset.

In [None]:
'''imports'''
import os
import json
import pprint
import re
from timeit import default_timer as timer

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import scipy.misc
import tensorflow as tf
from tqdm import tqdm

from data.geeco_gym import PickAndPlaceEncodingV4, PickAndPlaceMetaV4
from utils.plotting import create_image_grid

In [None]:
'''path setup'''
# Base directory is taken from the GEECO_ROOT environment variable;
# os.environ[...] raises KeyError if the variable is not set.
root_path = os.environ['GEECO_ROOT']
# Select the dataset to process by uncommenting exactly ONE of the lines below.
# With all of them commented out, dataset_dir is never assigned in this cell and
# the print at the bottom raises NameError (unless the name is still bound from
# an earlier run of the notebook).
# dataset_dir = os.path.join(root_path, 'data', 'gym-push-pad1-cube1-v4')
# dataset_dir = os.path.join(root_path, 'data', 'gym-push-pad1-cube2-v4')
# dataset_dir = os.path.join(root_path, 'data', 'gym-push-pad2-cube1-v4')
# dataset_dir = os.path.join(root_path, 'data', 'gym-push-pad2-cube2-v4')
# dataset_dir = os.path.join(root_path, 'data', 'gym-pick-pad1-cube1-v4')
# dataset_dir = os.path.join(root_path, 'data', 'gym-pick-pad1-cube2-v4')
# dataset_dir = os.path.join(root_path, 'data', 'gym-pick-pad2-cube1-v4')
# dataset_dir = os.path.join(root_path, 'data', 'gym-pick-pad2-cube2-v4')
# dataset_dir = os.path.join(root_path, 'data', 'gym-pick-pad2-cube2-clutter4-v4')
# dataset_dir = os.path.join(root_path, 'data', 'gym-pick-pad2-cube2-clutter12-v4')
# Echo the selection so the chosen dataset is visible in the notebook output.
print(dataset_dir)

In [None]:
'''collect files'''
# Compressed record files live in <dataset_dir>/data; keep them sorted by name
# so that file names and full paths share one stable order.
tfrecord_dir = os.path.join(dataset_dir, 'data')
tfrecord_files = sorted(
    name for name in os.listdir(tfrecord_dir)
    if name.endswith('.tfrecord.zlib'))
tfrecord_paths = [os.path.join(tfrecord_dir, name) for name in tfrecord_files]
# The dataset-wide meta information is stored as JSON and is passed to
# PickAndPlaceMetaV4 as keyword arguments.
meta_info_path = os.path.join(dataset_dir, 'meta', 'meta_info.json')
with open(meta_info_path, 'r') as fp:
    meta_info_dict = json.load(fp)
meta = PickAndPlaceMetaV4(**meta_info_dict)
# Sanity output: number of records, the first few file names, the raw meta dict.
print(len(tfrecord_files))
pprint.pprint(tfrecord_files[:10])
pprint.pprint(meta_info_dict)

In [None]:
'''parsing function'''
def _parse_record_v4(proto_example, meta):
    """Decode one serialized sequence example into a dict of tensors.

    Args:
        proto_example: serialized example, as yielded by the TFRecord dataset.
        meta: meta object handed to PickAndPlaceEncodingV4; its img_height and
            img_width attributes define the frame size.

    Returns:
        A dict holding the parsed sequence features. 'rgb' is reshaped to
        [-1, img_height, img_width, 3] and divided by 255.0; 'depth' is
        reshaped to [-1, img_height, img_width, 1]. The context features are
        parsed but not included.
    """
    context_decoder, sequence_decoder = PickAndPlaceEncodingV4(meta).decode()
    _, sequence_data = tf.parse_single_sequence_example(
        serialized=proto_example,
        context_features=context_decoder,
        sequence_features=sequence_decoder)
    # Only the sequence features are kept (original note: context == meta here).
    parsed_example = dict(sequence_data)
    # Leading -1 leaves the number of frames to be inferred.
    frame_dims = [-1, meta.img_height, meta.img_width]
    rgb_frames = tf.reshape(parsed_example['rgb'], frame_dims + [3])
    depth_frames = tf.reshape(parsed_example['depth'], frame_dims + [1])
    # Per the original note, RGB is recorded as uint8 values in [0 .. 255].
    parsed_example['rgb'] = rgb_frames / 255.0
    parsed_example['depth'] = depth_frames
    return parsed_example

In [None]:
'''build dataset graph'''
num_threads = 4
# Read the ZLIB-compressed records from the sorted paths and decode each one
# with _parse_record_v4; num_threads sets the parallelism of both stages.
# NOTE(review): the extraction loop pairs the i-th fetched record with the i-th
# entry of tfrecord_paths, so it relies on records arriving in path order —
# confirm this holds with num_parallel_reads > 1.
dataset = tf.data.TFRecordDataset(
    tfrecord_paths,
    compression_type='ZLIB',
    num_parallel_reads=num_threads,
).map(
    lambda serialized: _parse_record_v4(serialized, meta),
    num_parallel_calls=num_threads,
)
# One-shot iterator: every sess.run(data) yields the next parsed record.
iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()

In [None]:
'''prepare output directories'''
# Output layout: <dataset_dir>/images/{targets,keyframes}/{rgb,depth}
images_dir = os.path.join(dataset_dir, 'images')
rgb_targets_dir = os.path.join(images_dir, 'targets', 'rgb')
depth_targets_dir = os.path.join(images_dir, 'targets', 'depth')
rgb_keyframes_dir = os.path.join(images_dir, 'keyframes', 'rgb')
depth_keyframes_dir = os.path.join(images_dir, 'keyframes', 'depth')
# exist_ok=True makes the cell safe to re-run on an already prepared dataset.
for out_dir in (rgb_targets_dir, depth_targets_dir,
                rgb_keyframes_dir, depth_keyframes_dir):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
'''session setup'''
# Close the session left over from a previous run of this cell (if any) before
# opening a new one, so that re-running the notebook does not pile up sessions.
try:
    sess.close()
except NameError:
    # First execution: no earlier session exists yet.
    pass
except Exception as err:
    # Closing is best effort; report the problem instead of silently hiding it.
    # (A bare "except:" would also swallow KeyboardInterrupt / SystemExit.)
    print('>>> Could not close previous session: %s' % err)
sess = tf.InteractiveSession()

In [None]:
'''loop over tfrecord_paths and extract targets from corresponding records'''
def _export_rgb(rgb_frame, out_path, label):
    """Write one RGB frame to out_path as PNG and verify it by reading it back.

    Args:
        rgb_frame: float image; the parsing step divides RGB by 255, so values
            are expected to lie in [0, 1].
        out_path: destination path of the PNG file.
        label: record name used in the warning message.
    """
    # Round before casting: a plain astype(np.uint8) truncates, so a tiny float
    # error in the normalized value could shift a pixel down by one.
    Image.fromarray(np.rint(rgb_frame * 255).astype(np.uint8)).save(out_path)
    rgb_import = np.array(Image.open(out_path), dtype=np.float32) / 255.0
    # Explicit check instead of assert, which is stripped when running with -O.
    if not np.allclose(rgb_frame, rgb_import):
        print(">>> Faulty RGB export for %s" % label)


def _export_depth(depth_frame, out_path, label):
    """Write one depth frame to out_path as .npy and verify it by reading it back.

    Args:
        depth_frame: depth image array, stored unchanged via np.save.
        out_path: destination path of the .npy file.
        label: record name used in the warning message.
    """
    np.save(out_path, depth_frame)
    depth_import = np.load(out_path)
    if not np.allclose(depth_frame, depth_import):
        print(">>> Faulty depth export for %s" % label)


for p in tqdm(tfrecord_paths):
    # NOTE(review): one fetch per path — this assumes the dataset yields exactly
    # one record per file, in the order of tfrecord_paths; confirm.
    d = sess.run(data)
    filename = os.path.basename(p).split('.')[0]
    # The target (goal) image is the last frame of the record.
    rgb_target = np.squeeze(d['rgb'][-1])
    rgb_target_path = os.path.join(rgb_targets_dir, filename + '.png')
    _export_rgb(rgb_target, rgb_target_path, filename)
    depth_target = np.squeeze(d['depth'][-1])
    depth_target_path = os.path.join(depth_targets_dir, filename + '.npy')
    _export_depth(depth_target, depth_target_path, filename)
    # Look up the corresponding keyframe file via the first run of digits in
    # the file name; without an id or without the file, skip keyframe export.
    id_match = re.search(r'\d+', filename)
    if id_match is None:
        continue
    record_id = id_match.group(0)
    keyframe_filename = 'key_frames_%s.json' % (record_id, )
    keyframe_file = os.path.join(dataset_dir, 'data', keyframe_filename)
    if not os.path.exists(keyframe_file):
        continue
    with open(keyframe_file) as fp:
        keyframe_dict = json.load(fp)
    keyframe_indices = keyframe_dict['key_frames']
    last_idx = len(d['rgb']) - 1
    for key_idx in keyframe_indices:
        # Clamp to the last available frame. (An earlier variant shifted the
        # index by +1 so that the scene had settled; that shift is not applied.)
        key_idx = min(key_idx, last_idx)
        rgb_target = np.squeeze(d['rgb'][key_idx])
        rgb_target_path = os.path.join(rgb_keyframes_dir, filename + '_%04d' % key_idx + '.png')
        _export_rgb(rgb_target, rgb_target_path, filename)
        depth_target = np.squeeze(d['depth'][key_idx])
        depth_target_path = os.path.join(depth_keyframes_dir, filename + '_%04d' % key_idx + '.npy')
        _export_depth(depth_target, depth_target_path, filename)