In [1]:
import tensorflow as tf



In [2]:
# Let tf.data pick parallelism and prefetch buffer sizes at runtime.
AUTO = tf.data.AUTOTUNE

# Number of examples per batch produced by the input pipelines.
BATCH_SIZE = 4

# Bucket prefix holding the TFRecord shards; shard names start with the split
# name (e.g. `train-*`, `val-*`).
GCS_PATH = "gs://sidewalks-tfx-hf/sidewalks-tfrecords"

In [3]:
def parse_tfr(proto):
    """Decode one serialized example into a dict of image and label tensors.

    The record is expected to carry two scalar string features, "image" and
    "label", each holding a serialized tensor that is deserialized as float32.

    Args:
        proto: A scalar string tensor containing one serialized example.

    Returns:
        A dict with the decoded image under "pixel_values" and the decoded
        label under "label".
    """
    # Both features share the same spec: a single serialized-tensor byte string.
    serialized_tensor = tf.io.FixedLenFeature([], tf.string)
    example = tf.io.parse_single_example(
        proto, {"image": serialized_tensor, "label": serialized_tensor}
    )

    pixel_values = tf.io.parse_tensor(example["image"], out_type=tf.float32)
    label = tf.io.parse_tensor(example["label"], out_type=tf.float32)
    return dict(pixel_values=pixel_values, label=label)


def prepare_dataset(split="train", batch_size=BATCH_SIZE):
    """Build a batched, prefetched `tf.data.Dataset` for one data split.

    TFRecord shards matching `{GCS_PATH}/{split}-*` are read in parallel,
    parsed with `parse_tfr`, shuffled (training split only), batched, and
    prefetched.

    Args:
        split: Which split to load; either "train" or "val".
        batch_size: Number of examples per batch. Also determines the shuffle
            buffer size (`batch_size * 2`) for the training split.

    Returns:
        A `tf.data.Dataset` yielding dicts with "pixel_values" and "label"
        entries, as produced by `parse_tfr` and then batched.

    Raises:
        ValueError: If `split` is not "train" or "val".
        FileNotFoundError: If no files match the split's shard pattern.
    """
    if split not in ("train", "val"):
        raise ValueError(
            "Invalid split provided. Supported splits are: `train` and `val`."
        )

    pattern = f"{GCS_PATH}/{split}-*"
    # `glob` already returns a list, so it can be passed on directly.
    filenames = tf.io.gfile.glob(pattern)
    if not filenames:
        # Fail fast with a clear message rather than building a pipeline from
        # an empty file list.
        raise FileNotFoundError(f"No TFRecord files found matching `{pattern}`.")

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    dataset = dataset.map(parse_tfr, num_parallel_calls=AUTO)

    if split == "train":
        # Only the training split is shuffled; the buffer holds two batches'
        # worth of parsed examples.
        dataset = dataset.shuffle(batch_size * 2)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTO)
    return dataset

In [4]:
# Build the input pipelines for both splits (default batch size); only the
# training pipeline is shuffled.
train_dataset = prepare_dataset(split="train")
val_dataset = prepare_dataset(split="val")

In [5]:
# Sanity check: pull a single training batch and show the tensor shapes.
for batch in train_dataset.take(1):
    image_shape = batch["pixel_values"].shape
    label_shape = batch["label"].shape
    print(image_shape, label_shape)

(4, 512, 512, 3) (4, 512, 512)


In [6]:
# Sanity check: pull a single validation batch and show the tensor shapes.
for batch in val_dataset.take(1):
    image_shape = batch["pixel_values"].shape
    label_shape = batch["label"].shape
    print(image_shape, label_shape)

(4, 512, 512, 3) (4, 512, 512)
