## Read and parse TFRecord

In [1]:
# Import modules and this file should be outside learning_to_simulate code folder
import functools
import os
import json
import tensorflow.compat.v1 as tf
from learning_to_simulate import reading_utils
import pickle
import numpy as np

In [2]:
# Set datapath and validation set
data_path = './datasets/WaterDropSample'
filename = 'valid.tfrecord'

In [3]:
# Read metadata
def _read_metadata(data_path):
    with open(os.path.join(data_path, 'metadata.json'), 'rt') as fp:
        return json.loads(fp.read())

# Fetch metadata
metadata = _read_metadata(data_path)

# View records
ds = tf.data.TFRecordDataset([os.path.join(data_path, filename)])
ds = ds.map(functools.partial(
    reading_utils.parse_serialized_simulation_example, metadata=metadata))

context:  {'particle_type': <tf.Tensor 'Reshape_1:0' shape=(None,) dtype=int64>, 'key': <tf.Tensor 'ParseSingleSequenceExample/ParseSequenceExample/ParseSequenceExampleV2:3' shape=() dtype=int64>}
features:  {'position': <tf.Tensor 'Reshape:0' shape=(1001, None, 2) dtype=float32>}


In [4]:
# Print original dataset
ds

<DatasetV1Adapter shapes: ({particle_type: (None,), key: ()}, {position: (1001, None, 2)}), types: ({particle_type: tf.int64, key: tf.int64}, {position: tf.float32})>

In [5]:
# Fetch and inspect the first element
# 482 particles of type 5 (Water)
# 1001 time slices and position data (x, y)
next(iter(ds))

({'particle_type': <tf.Tensor: shape=(482,), dtype=int64, numpy=
  array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
         5, 5, 5, 5, 5, 5, 5, 5, 5,

In [6]:
# Convert to list
lds = list(ds)

In [7]:
# Extract array information
ptypes = lds[0][0]['particle_type'].numpy()
key = lds[0][0]['key'].numpy()
positions = lds[0][1]['position'].numpy()
ptypes

array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,

In [8]:
# Write a sequence example
seq = tf.train.SequenceExample(
        context=tf.train.Features(feature={
            'particle_type': tf.train.Feature(
                int64_list=tf.train.Int64List(value=ptypes)
            )
        }),
        feature_lists=tf.train.FeatureLists(feature_list={
            'position': tf.train.FeatureList(feature=[
                tf.train.Feature(float_list=tf.train.FloatList(value=positions.flatten()))
            ])
        })
    )

In [9]:
writer = tf.python_io.TFRecordWriter('test.tfrecord')
writer.write(seq.SerializeToString())
writer.close()

In [10]:
# Set datapath and validation set
# View records
dt = tf.data.TFRecordDataset(['./test.tfrecord'])
dt = dt.map(functools.partial(
    reading_utils.parse_serialized_simulation_example, metadata=metadata))

context:  {'particle_type': <tf.Tensor 'Reshape_1:0' shape=(None,) dtype=int64>, 'key': <tf.Tensor 'ParseSingleSequenceExample/ParseSequenceExample/ParseSequenceExampleV2:3' shape=() dtype=int64>}
features:  {'position': <tf.Tensor 'Reshape:0' shape=(1001, None, 2) dtype=float32>}


In [11]:
# Print the original dataset
ds

<DatasetV1Adapter shapes: ({particle_type: (None,), key: ()}, {position: (1001, None, 2)}), types: ({particle_type: tf.int64, key: tf.int64}, {position: tf.float32})>

In [12]:
# Print the created dateset
dt

<DatasetV1Adapter shapes: ({particle_type: (None,), key: ()}, {position: (1001, None, 2)}), types: ({particle_type: tf.int64, key: tf.int64}, {position: tf.float32})>

In [13]:
next(iter(dt))

InvalidArgumentError: Name: <unknown>, Context feature: particle_type.  Data types don't match. Expected type: string
	 [[{{node ParseSingleSequenceExample/ParseSequenceExample/ParseSequenceExampleV2}}]] [Op:IteratorGetNext]