In [1]:
import tensorflow as tf

In [2]:
# Path prefix under which the TFRecord shards are globbed as "<GCS_PATH>/<split>-*".
GCS_PATH = "gs://pets-tfrecords/pets-tfrecords"
# Default number of examples per batch used by `prepare_dataset`.
BATCH_SIZE = 4
# Lets tf.data pick parallelism / prefetch buffer sizes dynamically.
AUTO = tf.data.AUTOTUNE

In [3]:
def parse_tfr(proto):
    """Decode one serialized example into dense image and label tensors.

    Each record stores the image and the label as flat float32 value lists,
    alongside int64 lists holding their shapes. The flat values are
    reshaped back to the stored shapes.

    Args:
        proto: A scalar string tensor holding one serialized example.

    Returns:
        A dict with keys ``"pixel_values"`` (the reshaped image) and
        ``"label"`` (the reshaped label).
    """
    schema = {
        "image": tf.io.VarLenFeature(tf.float32),
        "image_shape": tf.io.VarLenFeature(tf.int64),
        "label": tf.io.VarLenFeature(tf.float32),
        "label_shape": tf.io.VarLenFeature(tf.int64),
    }
    parsed = tf.io.parse_single_example(proto, schema)

    def _restore(values_key, shape_key):
        # VarLenFeature yields sparse tensors; densify both the target shape
        # and the flat values before reshaping.
        target_shape = tf.sparse.to_dense(parsed[shape_key])
        flat_values = tf.sparse.to_dense(parsed[values_key])
        return tf.reshape(flat_values, target_shape)

    return {
        "pixel_values": _restore("image", "image_shape"),
        "label": _restore("label", "label_shape"),
    }


def prepare_dataset(GCS_PATH=GCS_PATH, split="train", batch_size=BATCH_SIZE):
    """Build a batched, prefetched `tf.data` pipeline for one dataset split.

    Shards are discovered by globbing ``f"{GCS_PATH}/{split}-*"``, read as
    TFRecords, and decoded with `parse_tfr`. The training split is shuffled
    before batching; the validation split is not.

    Args:
        GCS_PATH: Path prefix under which the TFRecord shards live.
        split: Which split to load; either ``"train"`` or ``"val"``.
        batch_size: Number of examples per batch.

    Returns:
        A `tf.data.Dataset` yielding batched dicts with the keys produced by
        `parse_tfr` (``"pixel_values"`` and ``"label"``).

    Raises:
        ValueError: If `split` is not ``"train"`` or ``"val"``.
        FileNotFoundError: If no shard matches the glob pattern.
    """
    if split not in ("train", "val"):
        raise ValueError(
            "Invalid split provided. Supported splits are: `train` and `val`."
        )

    pattern = f"{GCS_PATH}/{split}-*"
    filenames = tf.io.gfile.glob(pattern)
    if not filenames:
        # Fail early with the offending pattern instead of handing an empty
        # file list to TFRecordDataset.
        raise FileNotFoundError(f"No TFRecord files match pattern: {pattern}")

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    dataset = dataset.map(parse_tfr, num_parallel_calls=AUTO)

    if split == "train":
        # NOTE(review): the shuffle buffer holds only two batches' worth of
        # examples, so shuffling is quite local - confirm this is intended.
        dataset = dataset.shuffle(batch_size * 2)

    dataset = dataset.batch(batch_size)
    return dataset.prefetch(AUTO)

## Check Validity of Dataset

In [4]:
# Build both input pipelines with the default path prefix and batch size.
train_dataset = prepare_dataset(split="train")
val_dataset = prepare_dataset(split="val")

In [5]:
# Pull a single training batch and print the image and label shapes.
for batch in train_dataset.take(1):
    print(*(batch[key].shape for key in ("pixel_values", "label")))

(4, 128, 128, 3) (4, 128, 128)


In [6]:
# Pull a single validation batch and print the image and label shapes.
for batch in val_dataset.take(1):
    print(*(batch[key].shape for key in ("pixel_values", "label")))

(4, 128, 128, 3) (4, 128, 128)
