In [None]:
# GCP project that the gcloud CLI commands in this notebook should target.
project_id = 'stairnet-unlabeled'
# Uncomment to authenticate the gcloud CLI interactively if needed.
# !gcloud auth login
!gcloud config set project {project_id}

# Mount Google Drive at /content/drive and authenticate this Colab session
# (presumably required for reading the gs:// bucket below — confirm).
from google.colab import drive, auth
drive.mount('/content/drive')
auth.authenticate_user()

Updated property [core/project].
Mounted at /content/drive


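## Load TFRecord data

Read the unlabeled TFRecord shard(s) from the GCS bucket and build a `tf.data` input pipeline of `(image, label)` batches.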

In [None]:
# Imports and pipeline configuration used by the cells below.
import tensorflow as tf

# GCS folder containing the TFRecord shard(s).
GCS_PATH = "gs://bucket_name/tfrecords_out_folder"

# Let tf.data pick parallelism / prefetch depth at runtime.
AUTO = tf.data.AUTOTUNE

In [None]:
# Resolve the TFRecord file(s) to read. Fail fast if nothing matches: an empty
# list would otherwise only blow up later, deep inside the tf.data pipeline.
TFRECORD_PATTERN = GCS_PATH + '/_training_01of01.tfrecord'
UNLABELED_FILENAMES = tf.io.gfile.glob(TFRECORD_PATTERN)
assert UNLABELED_FILENAMES, f"No TFRecord files matched {TFRECORD_PATTERN!r}"

In [None]:
# Dataset with one element per TFRecord file path; build_dataset() later
# interleaves the records contained in these files.
filenames_dataset = tf.data.Dataset.from_tensor_slices(tensors=UNLABELED_FILENAMES)

In [None]:
# Schema of each serialized tf.train.Example: the image bytes, plus a
# variable-length (possibly absent) list of string labels.
feature_description = {
    'image': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
    'label': tf.io.FixedLenSequenceFeature(
        shape=[], dtype=tf.string, allow_missing=True
    ),
}


@tf.function
def _parse_function(example_proto):
    """Deserialize one tf.train.Example using `feature_description`.

    Args:
        example_proto: scalar string tensor holding a serialized Example.

    Returns:
        Dict with key 'image' (scalar string tensor) and key 'label'
        (1-D string tensor, empty when the feature is missing).
    """
    parsed = tf.io.parse_single_example(
        serialized=example_proto, features=feature_description
    )
    return parsed

# Preprocess Image
@tf.function
def process_image_tfrecord(record):
    """Turn a parsed record into an (image, label) pair.

    Args:
        record: dict produced by `_parse_function`, with JPEG bytes under
            'image' and string labels under 'label'.

    Returns:
        Tuple of (float32 image of shape (224, 224, 3) scaled to [0, 1],
        the label tensor unchanged).

    Raises:
        tf.errors.InvalidArgumentError (when iterated) if a decoded JPEG is
        not exactly 224x224x3.
    """
    image = tf.io.decode_jpeg(record['image'], channels=3)
    # ensure_shape both checks the size and sets the static shape. Unlike
    # tf.reshape, it cannot silently scramble an image whose pixel count
    # happens to equal 224*224*3 but whose height/width differ.
    image = tf.ensure_shape(image, [224, 224, 3])
    image = tf.cast(image, tf.float32) / 255.0  # floats in [0, 1]
    return image, record['label']

# Create a Dataset composed of TFRecords (paths to bucket)
def get_tfrecord(filename):
    """Open one TFRecord file as a dataset of serialized records.

    Intended as the `interleave` function in `build_dataset`. `interleave`
    already traces this function and parallelizes across files, so it needs
    no `@tf.function` wrapper. `num_parallel_reads` is also omitted: that
    option only parallelizes across multiple files, and a single path is
    passed here.

    Args:
        filename: scalar string tensor with the path of one TFRecord file.

    Returns:
        tf.data.TFRecordDataset yielding the file's serialized records.
    """
    return tf.data.TFRecordDataset(filename)

def build_dataset(dataset, batch_size=32, repeat=False, shuffle_buffer=30000,
                  cache_filename=''):
    """Build the training input pipeline from a dataset of TFRecord paths.

    Stages: interleave file reads -> parse Example -> decode/normalize image
    -> cache -> shuffle -> (optional) repeat -> batch -> prefetch.

    Args:
        dataset: tf.data.Dataset of TFRecord file paths (scalar strings).
        batch_size: number of examples per batch.
        repeat: if True, repeat the data indefinitely.
        shuffle_buffer: shuffle buffer size, in examples (the original
            notebook sized it as sample_size // 10).
        cache_filename: forwarded to `Dataset.cache`; '' caches in memory,
            a file path caches to disk.

    Returns:
        A batched, prefetched dataset of (image, label) pairs.
    """
    autotune = tf.data.AUTOTUNE

    # Read records from the files in parallel.
    dataset = dataset.interleave(get_tfrecord, num_parallel_calls=autotune)

    # Transformation: IO intensive (deserialize the Example protos).
    dataset = dataset.map(_parse_function, num_parallel_calls=autotune)

    # Transformation: CPU intensive (JPEG decode + normalization).
    dataset = dataset.map(process_image_tfrecord, num_parallel_calls=autotune)

    # Cache once, before shuffle/repeat/batch. Caching after repeat() would try
    # to store an infinite stream, and caching after shuffle() would replay the
    # first epoch's order on every later epoch.
    # NOTE(review): decoded float32 224x224x3 images are ~600 KB each; pass a
    # `cache_filename` (or cache before decoding) if they do not fit in RAM.
    dataset = dataset.cache(cache_filename)

    # Shuffle individual examples (not whole batches); reshuffled each epoch.
    dataset = dataset.shuffle(shuffle_buffer)

    if repeat:
        dataset = dataset.repeat()

    # NOTE(review): batch() needs every 'label' to have the same length;
    # switch to padded_batch if label lengths can vary.
    dataset = dataset.batch(batch_size=batch_size)

    # Overlap preparing the next batch with consuming the current one.
    return dataset.prefetch(buffer_size=autotune)

In [None]:
# Assemble the (lazy) tf.data input pipeline over the TFRecord file list.
ds = build_dataset(dataset=filenames_dataset)