<a href="https://colab.research.google.com/github/ivelin/gui2refexp/blob/main/dataset/tfrecord_export.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#@title Checkout source files from github repo
![[ -d "gui2refexp" ]] || git clone https://github.com/ivelin/gui2refexp.git 

!cd gui2refexp && git pull


Cloning into 'gui2refexp'...
remote: Enumerating objects: 131, done.
remote: Counting objects: 100% (131/131), done.
remote: Compressing objects: 100% (98/98), done.
remote: Total 131 (delta 44), reused 106 (delta 29), pack-reused 0
Receiving objects: 100% (131/131), 8.58 MiB | 11.21 MiB/s, done.
Resolving deltas: 100% (44/44), done.
Already up to date.


##### Copyright 2019 The TensorFlow Authors.

In [4]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

The TFRecord format is a simple format for storing a sequence of binary records.

[Protocol buffers](https://developers.google.com/protocol-buffers/) are a cross-platform, cross-language library for efficient serialization of structured data.

Protocol messages are defined by `.proto` files; these are often the easiest way to understand a message type.

The `tf.train.Example` message (or protobuf) is a flexible message type that represents a `{"string": value}` mapping. It is designed for use with TensorFlow and is used throughout the higher-level APIs such as [TFX](https://www.tensorflow.org/tfx/).
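As a minimal sketch (the feature names `label`, `score`, and `text` are made up for illustration), the following builds a `tf.train.Example` from such a mapping, serializes it to bytes, and parses it back. Every value must be wrapped in one of the three list types, even a single scalar:

```python
import tensorflow as tf

# A {"string": value} mapping in which every value is wrapped in a list type.
example = tf.train.Example(features=tf.train.Features(feature={
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    'score': tf.train.Feature(float_list=tf.train.FloatList(value=[0.5, 0.25])),
    'text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'hello'])),
}))

serialized = example.SerializeToString()             # plain bytes
restored = tf.train.Example.FromString(serialized)   # parse back into a proto

print(restored.features.feature['text'].bytes_list.value[0])      # b'hello'
print(list(restored.features.feature['score'].float_list.value))  # [0.5, 0.25]
```

The serialized bytes are exactly what gets stored as the payload of one TFRecord record.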

This notebook demonstrates how to create, parse, and use the `tf.train.Example` message, and then serialize, write, and read `tf.train.Example` messages to and from `.tfrecord` files.

Note: While useful, these structures are optional. There is no need to convert existing code to use TFRecords, unless you are [using tf.data](https://www.tensorflow.org/guide/data) and reading data is still the bottleneck to training. You can refer to [Better performance with the tf.data API](https://www.tensorflow.org/guide/data_performance) for dataset performance tips.

Note: In general, you should shard your data across multiple files so that you can parallelize I/O (within a single host or across multiple hosts). The rule of thumb is to have at least 10 times as many files as there will be hosts reading data. At the same time, each file should be large enough (at least 10 MB, and ideally 100 MB or more) to benefit from I/O prefetching. For example, say you have `X` GB of data and you plan to train on up to `N` hosts. Ideally, you should shard the data into ~`10*N` files, as long as each shard, ~`X/(10*N)` GB, is at least 10 MB (ideally 100 MB or more). If shards would be smaller than that, you might need to create fewer of them, trading some parallelism for better I/O prefetching.
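The rule of thumb above can be turned into a quick calculation. The helper below, `plan_shards`, is a hypothetical sketch rather than a TensorFlow API; it assumes decimal units (1 GB = 1000 MB) and uses the 10 MB floor as the default minimum shard size:

```python
import math
from typing import Tuple

def plan_shards(total_gb: float, num_hosts: int,
                min_shard_mb: float = 10.0) -> Tuple[int, float]:
  """Returns (num_shards, shard_size_mb) for the ~10-files-per-host rule."""
  total_mb = total_gb * 1000  # assumption: decimal units
  num_shards = 10 * num_hosts
  if total_mb / num_shards < min_shard_mb:
    # Shards would be too small to prefetch well: use fewer, larger files.
    num_shards = max(1, math.floor(total_mb / min_shard_mb))
  return num_shards, total_mb / num_shards

print(plan_shards(100, 4))  # (40, 2500.0)
print(plan_shards(1, 20))   # (100, 10.0)
```

Passing `min_shard_mb=100` applies the stricter "ideally 100 MB" target instead of the 10 MB minimum.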

## Setup

In [3]:
import tensorflow as tf

import numpy as np
import IPython.display as display

## TFRecords format details

A TFRecord file contains a sequence of records. The file can only be read sequentially.

Each record contains a byte string holding the data payload, plus the data length and CRC-32C ([32-bit CRC](https://en.wikipedia.org/wiki/Cyclic_redundancy_check#CRC-32_algorithm) using the [Castagnoli polynomial](https://en.wikipedia.org/wiki/Cyclic_redundancy_check#Standards_and_common_use)) hashes of both for integrity checking.
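To make the checksum concrete, here is a minimal pure-Python sketch of CRC-32C. It processes one bit at a time using the reflected form of the Castagnoli polynomial (`0x82F63B78`), so it is far too slow for real files; it is meant only to illustrate the algorithm. The standard check input `b"123456789"` should produce `0xE3069283`:

```python
def crc32c(data: bytes) -> int:
  """Bitwise CRC-32C (Castagnoli), reflected polynomial 0x82F63B78."""
  crc = 0xFFFFFFFF  # initial value
  for byte in data:
    crc ^= byte
    for _ in range(8):
      # Shift right; XOR in the polynomial when the dropped bit was 1.
      crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
  return crc ^ 0xFFFFFFFF  # final XOR

print(hex(crc32c(b"123456789")))  # 0xe3069283
```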

Each record is stored in the following format:

    uint64 length
    uint32 masked_crc32_of_length
    byte   data[length]
    uint32 masked_crc32_of_data
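
This length-prefixed layout is what makes sequential reading natural: without an index, the only way to reach a record is to walk past every record before it. The sketch below splits an in-memory buffer into payloads. `iter_record_payloads` and `frame` are hypothetical helpers, the integers are assumed to be little-endian, and the CRC fields are skipped rather than verified, so this performs no integrity checking:

```python
import struct
from typing import Iterator

def iter_record_payloads(buf: bytes) -> Iterator[bytes]:
  """Yields each record's data; CRC fields are skipped, NOT verified."""
  offset = 0
  while offset < len(buf):
    if offset + 12 > len(buf):
      raise ValueError('truncated record header')
    (length,) = struct.unpack_from('<Q', buf, offset)  # uint64 length
    start = offset + 12  # skip length (8 bytes) + masked_crc32_of_length (4)
    end = start + length
    if end + 4 > len(buf):
      raise ValueError('truncated record data')
    yield buf[start:end]
    offset = end + 4  # skip masked_crc32_of_data (4 bytes)

def frame(data: bytes) -> bytes:
  """Hypothetical helper: frames data with placeholder all-zero CRC fields."""
  return struct.pack('<Q', len(data)) + bytes(4) + data + bytes(4)

buf = frame(b'hello') + frame(b'') + frame(b'tfrecord')
print(list(iter_record_payloads(buf)))  # [b'hello', b'', b'tfrecord']
```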

The records are concatenated together to produce the file. CRCs are
[described here](https://en.wikipedia.org/wiki/Cyclic_redundancy_check), and
the mask of a CRC is:

    masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul
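
The expression above is meant to be evaluated on 32-bit unsigned values, so both the rotation and the addition wrap modulo 2**32. Python integers are unbounded, so a sketch of the mask (and of its inverse, which a reader could use to recover the stored CRC) has to clip to 32 bits explicitly. The names `mask_crc`, `unmask_crc`, and `MASK_DELTA` are illustrative, not TensorFlow APIs:

```python
MASK_DELTA = 0xA282EAD8
U32 = 0xFFFFFFFF

def mask_crc(crc: int) -> int:
  """Rotates a 32-bit CRC right by 15 bits, then adds MASK_DELTA mod 2**32."""
  rotated = ((crc >> 15) | (crc << 17)) & U32
  return (rotated + MASK_DELTA) & U32

def unmask_crc(masked: int) -> int:
  """Inverts mask_crc: subtracts MASK_DELTA, then rotates left by 15 bits."""
  rotated = (masked - MASK_DELTA) & U32
  return ((rotated << 15) | (rotated >> 17)) & U32

print(hex(mask_crc(0)))                                # 0xa282ead8
print(unmask_crc(mask_crc(0xE3069283)) == 0xE3069283)  # True
```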


## TFRecord files in Python

The `tf.io` module also contains pure-Python functions for reading and writing TFRecord files.
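A minimal round trip, assuming TensorFlow 2.x with eager execution: write a few raw byte strings with `tf.io.TFRecordWriter` (a record is just bytes, so they need not be `tf.train.Example` protos), then read them back with `tf.data.TFRecordDataset`. The file name `demo.tfrecord` is arbitrary. Because each record adds 16 bytes of framing (8 for the length and 4 for each CRC), the three records below should occupy 64 bytes on disk:

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), 'demo.tfrecord')

# Each write() call appends one length-prefixed, checksummed record.
with tf.io.TFRecordWriter(path) as writer:
  for payload in [b'first', b'second', b'third']:
    writer.write(payload)

# TFRecordDataset yields each record as a scalar tf.string tensor, in file order.
records = [record.numpy() for record in tf.data.TFRecordDataset(path)]
print(records)                # [b'first', b'second', b'third']
print(os.path.getsize(path))  # 64
```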

### Reading a TFRecord file

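Each record in this dataset's TFRecord file is a serialized `tf.train.Example`. Load the file with `tf.data.TFRecordDataset`, then parse each record with `tf.train.Example.ParseFromString`: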

In [8]:
filenames = ["gui2refexp/dataset/pix2struct_data_data_refexp_test.tfrecord"]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset

<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

In [9]:
# Count the records by iterating over the whole dataset once.
record_count = 0
for raw_record in raw_dataset:
  record_count += 1

print(f'Total records in the dataset: {record_count}')

# Features whose full values are printed; for all others only the key is shown.
keys_to_print = {"image/ref_exp/text", "image/id", "image/object/num",
                 "image/view_hierarchy/description", "image/view_hierarchy/text"}

for raw_record in raw_dataset.take(3):
  example = tf.train.Example()
  example.ParseFromString(raw_record.numpy())
  print('---------------------------------------------')
  for key, feature in example.features.feature.items():
    if key in keys_to_print:
      print(key, feature)
    else:
      print(key)


Total records in the dataset: 616
---------------------------------------------
image/object/num float_list {
  value: 32.0
}

image/id bytes_list {
  value: "61654"
}

image/object/bbox/xmin
image/view_hierarchy/bbox/ymin
image/object/bbox/ymin
image/ref_exp/text bytes_list {
  value: "select the left side bottom image"
}

image/view_hierarchy/bbox/xmax
image/view_hierarchy/class/name
image/view_hierarchy/id/name
image/view_hierarchy/text bytes_list {
  value: ""
  value: "tv shows"
  value: ""
  value: ""
  value: "popular"
  value: "recently added"
  value: ""
  value: "i dont watch tv"
  value: " views"
  value: ""
  value: "ranjish"
  value: " views"
  value: ""
  value: ""
  value: ""
  value: "are ho ja re gender"
  value: ""
  value: ""
  value: "john smith"
  value: "appcrawler4 gmail com"
  value: ""
  value: ""
  value: ""
  value: "home"
  value: ""
  value: "live tv"
  value: ""
  value: "catch up tv"
  value: ""
  value: "tv shows"
  value: ""
  value: "yuppflix"
  value:

`ParseFromString` fills in a `tf.train.Example` proto, which is difficult to use as is, but it's fundamentally a representation of a:

```
Dict[str,
     Union[List[float],
           List[int],
           List[str]]]
```

The following code manually converts the `Example` to a dictionary of NumPy arrays, without using TensorFlow Ops. Refer to [the PROTO file](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/feature.proto) for details.

In [None]:
result = {}
# example.features.feature is the dictionary
for key, feature in example.features.feature.items():
  # The values are the Feature objects which contain a `kind` which contains:
  # one of three fields: bytes_list, float_list, int64_list

  kind = feature.WhichOneof('kind')
  result[key] = np.array(getattr(feature, kind).value)

result

## Walkthrough: Reading and writing image data

This is an end-to-end example of how to read and write image data using TFRecords. Using an image as input data, you will write the data as a TFRecord file, then read the file back and display the image.

This can be useful if, for example, you want to use several models on the same input dataset. Instead of storing the image data raw, it can be preprocessed into the TFRecords format, and that can be used in all further processing and modelling.

First, let's download [this image](https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg) of a cat in the snow and [this photo](https://upload.wikimedia.org/wikipedia/commons/f/fe/New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg) of the Williamsburg Bridge in NYC under construction.

### Fetch the images

In [None]:
cat_in_snow = tf.keras.utils.get_file(
    '320px-Felis_catus-cat_on_snow.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')

williamsburg_bridge = tf.keras.utils.get_file(
    '194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')

In [None]:
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a href="https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))

In [None]:
display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a href="https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))

### Write the TFRecord file

As before, encode the features as types compatible with `tf.train.Example`. This stores the raw image string as a feature, along with the height, width, depth, and an arbitrary `label` feature. The label is used when you write the file to distinguish the cat image from the bridge image. Use `0` for the cat image and `1` for the bridge image:

In [None]:
image_labels = {
    cat_in_snow : 0,
    williamsburg_bridge : 1,
}

In [None]:
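# Helpers that wrap a single value in the matching tf.train.Feature type.
def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, tf.Tensor):
    value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# This is an example, just using the cat image.
with open(cat_in_snow, 'rb') as f:
  image_string = f.read()

label = image_labels[cat_in_snow]

# Create a dictionary with features that may be relevant.
def image_example(image_string, label):
  image_shape = tf.io.decode_jpeg(image_string).shape

  feature = {
      'height': _int64_feature(image_shape[0]),
      'width': _int64_feature(image_shape[1]),
      'depth': _int64_feature(image_shape[2]),
      'label': _int64_feature(label),
      'image_raw': _bytes_feature(image_string),
  }

  return tf.train.Example(features=tf.train.Features(feature=feature))

# Print the first 15 lines of the proto's text representation.
for line in str(image_example(image_string, label)).split('\n')[:15]:
  print(line)
print('...')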

Notice that all of the features are now stored in the `tf.train.Example` message. Next, functionalize the code above and write the example messages to a file named `images.tfrecords`:

In [None]:
# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.train.Example` messages.
# Then, write to a `.tfrecords` file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for filename, label in image_labels.items():
    with open(filename, 'rb') as f:
      image_string = f.read()
    tf_example = image_example(image_string, label)
    writer.write(tf_example.SerializeToString())

In [None]:
!du -sh {record_file}

### Read the TFRecord file

You now have the file `images.tfrecords` and can iterate over the records in it to read back what you wrote. Since this example only reproduces the image, the only feature you need is the raw image string. You could extract it with the getters described above, namely `example.features.feature['image_raw'].bytes_list.value[0]`; the code below instead parses each record with `tf.io.parse_single_example` and a feature description. You can also use the labels to determine which record is the cat and which one is the bridge:

In [None]:
raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

# Create a dictionary describing the features.
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  # Parse the input tf.train.Example proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset

Recover the images from the TFRecord file:

In [None]:
for image_features in parsed_image_dataset:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))