In [None]:
import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop
import tensorflow_cloud as tfc
import time
import os

In [None]:
# Directory holding the GZIP-compressed TFRecord shards (train* / validation*).
GCS_PATH = "gs://mchrestkha-demo-env-ml-examples/catsdogs/tfrecords/tfrecorder-20200930-193548-to-tfr"

# Target [height, width] every decoded image is resized to.
IMAGE_SIZE = [150, 150]
# Number of examples per training/validation batch.
BATCH_SIZE = 5

# Sentinel letting tf.data choose parallelism / prefetch depth at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + "/train*.tfrecord.gz")
VALID_FILENAMES = tf.io.gfile.glob(GCS_PATH + "/validation*.tfrecord.gz")

# An empty glob (wrong GCS_PATH, missing shards, no access) would silently build
# empty datasets and only surface as a confusing failure during training.
assert TRAINING_FILENAMES, "No training TFRecords found under " + GCS_PATH
assert VALID_FILENAMES, "No validation TFRecords found under " + GCS_PATH

In [None]:
# Report how many shards each split matched.
for caption, shard_paths in (("Train TFRecord Files:", TRAINING_FILENAMES),
                             ("Validation TFRecord Files:", VALID_FILENAMES)):
    print(caption, len(shard_paths))

In [None]:
def read_tfrecord(example):
    """Decode one serialized TFRecord example into an ``(image, label)`` pair.

    The ``image`` feature holds base64-encoded raw ``uint8`` pixel bytes. They are
    decoded, reshaped with the stored height/width/channel counts, resized to
    ``IMAGE_SIZE`` and cast back to ``uint8``.

    Args:
        example: A scalar string tensor containing one serialized ``tf.train.Example``.

    Returns:
        Tuple ``(image, label)``: a ``uint8`` image tensor whose spatial size is
        ``IMAGE_SIZE``, and the ``int64`` label.
    """
    int_field = tf.io.FixedLenFeature([], tf.int64)
    str_field = tf.io.FixedLenFeature([], tf.string)
    feature_spec = {
        "image": str_field,
        "image_channels": int_field,
        "image_height": int_field,
        "image_name": str_field,
        "image_width": int_field,
        "label": int_field,
        "split": str_field,
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    # Pixels are stored base64-encoded; undo that before interpreting them as uint8.
    pixel_bytes = tf.io.decode_base64(parsed['image'])
    flat_pixels = tf.io.decode_raw(pixel_bytes, out_type=tf.uint8)

    stored_shape = tf.stack([parsed['image_height'],
                             parsed['image_width'],
                             parsed['image_channels']])
    full_image = tf.reshape(flat_pixels, stored_shape)

    # tf.image.resize yields float32; cast back so the output dtype stays uint8.
    resized = tf.image.resize(full_image, size=IMAGE_SIZE)
    return tf.cast(resized, tf.uint8), parsed['label']

In [None]:
def get_dataset(filenames, shuffle_buffer=200, batch_size=None):
    """Build a shuffled, batched tf.data pipeline over GZIP-compressed TFRecords.

    Args:
        filenames: List of TFRecord file paths (local or ``gs://``).
        shuffle_buffer: Size of the example-level shuffle buffer. Defaults to the
            previously hard-coded 200; pass 0 (or None) to skip shuffling, e.g.
            for evaluation data.
        batch_size: Examples per batch. ``None`` reads the global ``BATCH_SIZE``
            at call time, matching the original behaviour.

    Returns:
        A ``tf.data.Dataset`` of ``(images, labels)`` batches as produced by
        ``read_tfrecord``, with prefetching enabled.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    dataset = tf.data.TFRecordDataset(filenames=filenames, compression_type='GZIP')
    # Decode records in parallel; AUTOTUNE lets tf.data pick the degree of parallelism.
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.batch(batch_size)
    # Overlap input preparation with model execution.
    return dataset.prefetch(AUTOTUNE)

In [None]:
train_dataset = get_dataset(TRAINING_FILENAMES)
valid_dataset = get_dataset(VALID_FILENAMES)

# Inspect the (images, labels) structure the model will receive. element_spec is
# static metadata, so this does not read any records from storage.
train_dataset.element_spec

In [None]:
# Small 4-stage conv net for binary (sigmoid) classification.
# NOTE(review): read_tfrecord yields raw uint8 pixels (0-255) and nothing in this
# pipeline rescales them; consider normalizing (e.g. dividing by 255) before the
# first Conv2D.
model = tf.keras.models.Sequential([
    # Input size follows IMAGE_SIZE so the model stays in sync with the decoder;
    # 3 channels assumes RGB images.
    tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(*IMAGE_SIZE, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

model.compile(loss='binary_crossentropy',
              # `learning_rate` is the supported keyword; the old `lr` alias is
              # deprecated and rejected by newer Keras optimizers.
              optimizer=RMSprop(learning_rate=1e-4),
              metrics=['accuracy'])
model.summary()

In [None]:
EPOCHS = 10  # full passes over train_dataset

# Keep the History object so per-epoch loss/accuracy (train and validation) can be
# inspected or plotted after training instead of being discarded.
history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=valid_dataset,
    verbose=2,
)

In [None]:
# Each run exports to a new location stamped with the local time (second resolution).
export_path = time.strftime("gs://mchrestkha-demo-env-ml-examples/catsdogs/models/model_%Y%m%d_%H%M%S")
model.save(export_path)