In [None]:
import tensorflow as tf
from tensorflow.keras import Input, layers, Model, Sequential
import tensorflow_datasets as tfds
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

import re
import sys

# Setup

In [None]:
def connect_to_tpu():
    """Return a tf.distribute strategy for the hardware attached to this session.

    A TPU cluster is resolved first. If ``TPUClusterResolver()`` (or the
    ``master()`` lookup) raises ValueError, no TPU is assumed and the default
    strategy is returned; it runs on the CPU or a single GPU.

    Returns:
        ``tf.distribute.experimental.TPUStrategy`` when a TPU was resolved,
        otherwise ``tf.distribute.get_strategy()``.
    """
    try:
        tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        print("Running on TPU  ", tpu_resolver.master())
    except ValueError:
        # No TPU available. Note: nothing here probes for GPUs explicitly.
        tpu_resolver = None

    if not tpu_resolver:
        # Alternatives for other hardware:
        #   tf.distribute.MirroredStrategy()                          # one or more local GPUs
        #   tf.distribute.experimental.MultiWorkerMirroredStrategy()  # clusters of multi-GPU machines
        return tf.distribute.get_strategy()

    # Connect to and initialize the TPU system before building the strategy on it.
    tf.config.experimental_connect_to_cluster(tpu_resolver)
    tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))
    return tf.distribute.experimental.TPUStrategy(tpu_resolver)


strategy = connect_to_tpu()

# Replicas per synchronous step; later cells scale BATCH_SIZE and LEARNING_RATE by it.
print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
# Get the Google Cloud Storage (GCS) bucket path of the attached Kaggle dataset.

# Set these credentials after initializing the TPU (done in the previous cell).
from kaggle_datasets import KaggleDatasets
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

# Then get the GCS bucket path of the dataset.
GCS_DS_PATH = KaggleDatasets().get_gcs_path() # you can list the bucket with "!gsutil ls $GCS_DS_PATH"

print(GCS_DS_PATH)

# Do not load TFRecords from the local mount (e.g. '../input/enpt10000/train.tfrecord');
# use the GCS path instead, e.g. tf.io.gfile.glob(GCS_DS_PATH + '/train.tfrecord').
# NOTE(review): presumably because the TPU reads input from GCS rather than the notebook's disk — confirm.

# https://www.kaggle.com/philculliton/a-simple-tf-2-1-notebook
# A very useful notebook on using TPUs.

# https://www.tensorflow.org/datasets/gcs
# Using tfds with TPUs.

# Prepare the flower dataset

In [None]:
# Training hyper-parameters; batch size and learning rate scale with the replica count.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
EPOCHS = 10
LEARNING_RATE = 3e-5 * strategy.num_replicas_in_sync

# Image geometry, number of classes, and the seed used by data_augment.
HEIGHT = 512
WIDTH = 512
CHANNELS = 3
N_CLASSES = 104
seed = 101

# TFRecord shards for each split, globbed from the GCS bucket.
folder = f"/tfrecords-jpeg-{HEIGHT}x{WIDTH}"
train_files, test_files, val_files = (
    tf.io.gfile.glob(GCS_DS_PATH + folder + f'/{split}/*.tfrec')
    for split in ('train', 'test', 'val')
)


# Datasets utility functions
# AUTOTUNE lets tf.data pick parallelism and buffer sizes at runtime.
AUTO = tf.data.experimental.AUTOTUNE

def decode_image(image_data):
    """Decode one JPEG-encoded image into a float32 tensor scaled to [0, 1].

    Args:
        image_data: scalar tf.string tensor holding the JPEG bytes.

    Returns:
        float32 tensor of shape [HEIGHT, WIDTH, CHANNELS]. The reshape fails
        at runtime if the decoded pixel count does not match that shape.
    """
    # Use the shared CHANNELS constant rather than a literal 3 so the decode and
    # the reshape always agree with the rest of the pipeline.
    image = tf.image.decode_jpeg(image_data, channels=CHANNELS)
    image = tf.cast(image, tf.float32) / 255.0  # uint8 [0, 255] -> float32 [0, 1]
    # decode_jpeg leaves height/width statically unknown; the reshape pins them.
    image = tf.reshape(image, [HEIGHT, WIDTH, CHANNELS])
    return image

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into ``(image, label)``.

    The record must hold an ``image`` feature (JPEG bytes) and an int64
    ``class`` feature. The image is decoded with decode_image and the label
    is cast to int32.
    """
    parsed = tf.io.parse_single_example(
        example,
        {
            "image": tf.io.FixedLenFeature([], tf.string),  # JPEG bytes, scalar
            "class": tf.io.FixedLenFeature([], tf.int64),   # scalar class id
        },
    )
    return decode_image(parsed['image']), tf.cast(parsed['class'], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into ``(image, id)``.

    Meant for records without a ``class`` feature, i.e. the test split, whose
    flower classes are what the competition asks to predict. The ``id`` is
    returned as-is (a scalar tf.string).
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # JPEG bytes, scalar
        "id": tf.io.FixedLenFeature([], tf.string),     # scalar record id
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['id']

def load_dataset(filenames, labeled=True, ordered=False):
    """Build an unbatched dataset of parsed examples from TFRecord files.

    Args:
        filenames: list of TFRecord paths.
        labeled: if truthy, elements are ``(image, label)`` from
            read_labeled_tfrecord; otherwise ``(image, id)`` from
            read_unlabeled_tfrecord.
        ordered: accepted but never read. NOTE(review): it has no effect on
            element order.

    Returns:
        A tf.data.Dataset of parsed (unbatched) examples.
    """
    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.map(parse_fn, num_parallel_calls=AUTO)

def data_augment(image, label):
    """Randomly augment one training image; the label passes through unchanged.

    Every random op gets its own op-level seed derived from the global ``seed``.
    This function is traced into a graph by ``Dataset.map``. In a graph, distinct
    ops that share the same op-level seed generate the same random stream.
    Reusing ``seed`` everywhere therefore made all "probabilities" equal and
    made the two flips always fire together.

    Transforms, applied in this order:
      * flips: with probability 0.8, a random left-right flip and a random
        up-down flip (each decided independently);
      * crops: with probability 0.3, a random crop of 60/70/80/90% of each side,
        resized back to HEIGHT x WIDTH;
      * pixel: with probability 0.6, one of random saturation, random contrast,
        random brightness, or a fixed gamma (0.6) adjustment.

    Args:
        image: float tensor of shape [HEIGHT, WIDTH, CHANNELS].
        label: returned unchanged.

    Returns:
        ``(image, label)`` with the image still HEIGHT x WIDTH x CHANNELS.
    """
    # One independent draw per decision, each with a distinct op seed.
    p_spatial = tf.random.uniform([], minval=0, maxval=1, dtype='float32', seed=seed)
    p_pixel = tf.random.uniform([], minval=0, maxval=1, dtype='float32', seed=seed + 1)
    p_crop = tf.random.uniform([], minval=0, maxval=1, dtype='float32', seed=seed + 2)

    ### Spatial-level transforms
    if p_spatial >= .2:  # flips
        image = tf.image.random_flip_left_right(image, seed=seed + 3)
        image = tf.image.random_flip_up_down(image, seed=seed + 4)

    if p_crop >= .7:  # crops
        # Exactly one branch runs per example, so the crop ops may share a seed.
        if p_crop >= .95:
            image = tf.image.random_crop(image, size=[int(HEIGHT*.6), int(WIDTH*.6), CHANNELS], seed=seed + 5)
        elif p_crop >= .85:
            image = tf.image.random_crop(image, size=[int(HEIGHT*.7), int(WIDTH*.7), CHANNELS], seed=seed + 5)
        elif p_crop >= .8:
            image = tf.image.random_crop(image, size=[int(HEIGHT*.8), int(WIDTH*.8), CHANNELS], seed=seed + 5)
        else:
            image = tf.image.random_crop(image, size=[int(HEIGHT*.9), int(WIDTH*.9), CHANNELS], seed=seed + 5)
        # Bring every crop size back to the model's input resolution.
        image = tf.image.resize(image, size=[HEIGHT, WIDTH])

    ## Pixel-level transforms
    if p_pixel >= .4:  # pixel transformations
        if p_pixel >= .85:
            image = tf.image.random_saturation(image, lower=0, upper=2, seed=seed + 6)
        elif p_pixel >= .65:
            image = tf.image.random_contrast(image, lower=.8, upper=2, seed=seed + 7)
        elif p_pixel >= .5:
            image = tf.image.random_brightness(image, max_delta=.2, seed=seed + 8)
        else:
            image = tf.image.adjust_gamma(image, gamma=.6)

    return image, label

def get_training_dataset(filenames):
    """Build the cached, shuffled, augmented and batched training pipeline.

    Order matters: decoded images are cached *before* ``data_augment`` runs,
    so augmentation is re-drawn on every epoch. With ``cache()`` after the
    map, the first epoch's random augmentations were stored and replayed
    unchanged on every later epoch.

    Args:
        filenames: list of labeled TFRecord paths.

    Returns:
        A prefetched dataset of ``(images, labels)`` batches of BATCH_SIZE
        (the final batch may be smaller).
    """
    dataset = load_dataset(filenames, labeled=True)
    dataset = dataset.cache()  # cache decoded images only; augmentation must stay random
    dataset = dataset.shuffle(1000000)  # presumably larger than the split, i.e. a full shuffle
    dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_validation_dataset(filenames):
    """Build the labeled evaluation pipeline: parse, cache, batch, prefetch.

    No shuffling and no augmentation are applied.
    """
    return (
        load_dataset(filenames, labeled=True)
        .cache()
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

def get_test_dataset(filenames, labeled=True):
    """Build a cached, batched, prefetched pipeline for held-out TFRecords.

    Args:
        filenames: list of TFRecord paths.
        labeled: if True (the default, kept for existing callers), elements are
            ``(image, label)``; if False, ``(image, id)``.
            NOTE(review): per read_unlabeled_tfrecord, test records carry no
            ``class`` feature, so pass ``labeled=False`` for the test split.

    Returns:
        A dataset with no shuffling or augmentation.
    """
    dataset = load_dataset(filenames, labeled=labeled)
    dataset = dataset.cache()
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO)
    return dataset

def count_data_items(filenames):
    """Return the total number of examples across TFRecord shards.

    The count is read from each path instead of scanning the file: it is the
    first run of digits between a ``-`` and a ``.``. For example,
    ``flowers00-230.tfrec`` holds 230 items and ``00-512x512-798.tfrec``
    holds 798.

    Args:
        filenames: iterable of paths following that naming convention.

    Returns:
        The summed count as a Python int (0 for an empty input).

    Raises:
        ValueError: if a path contains no ``-<digits>.`` component.
    """
    # Compile once; '+' (not '*') so an empty digit run can't produce int('').
    pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        match = pattern.search(filename)
        if match is None:
            raise ValueError(f"cannot read an item count from file name {filename!r}")
        total += int(match.group(1))
    return total

def int_div_round_up(a, b):
    """Divide ``a`` by ``b`` rounding up (ceil(a / b) for positive ``b``).

    Uses ``(a + b - 1)`` floor-divided by ``b`` so integers stay exact.
    """
    quotient, _ = divmod(a + b - 1, b)
    return quotient

In [None]:
# Train data
NUM_TRAINING_IMAGES = count_data_items(train_files)
train_ds = get_training_dataset(train_files)

# Val data: count the validation shards (not the test shards) and use the
# validation pipeline builder.
NUM_VALIDATION_IMAGES = count_data_items(val_files)
val_ds = get_validation_dataset(val_files)

# Test data: test records have no "class" feature (see read_unlabeled_tfrecord),
# so parse them unlabeled; elements are (image, id).
NUM_TEST_IMAGES = count_data_items(test_files)
test_ds = load_dataset(test_files, labeled=False).cache().batch(BATCH_SIZE).prefetch(AUTO)

# Let the strategy split each global batch across its replicas.
train_dist_ds = strategy.experimental_distribute_dataset(train_ds)
val_dist_ds = strategy.experimental_distribute_dataset(val_ds)
# NOTE(review): test elements now include string ids; presumably these must be
# stripped (keep only the image) before iterating test_dist_ds on a TPU — confirm.
test_dist_ds = strategy.experimental_distribute_dataset(test_ds)

In [None]:
# Time iteration over the distributed input pipelines with tfds' benchmark helper.
# May be slow: the first full pass also fills each pipeline's in-memory .cache().


tfds.core.benchmark(train_dist_ds)
tfds.core.benchmark(val_dist_ds)
# The test pipeline is left out of the benchmark.
#tfds.core.benchmark(test_dist_ds)


# Strategy, Training loops & Validation functions

In [None]:
# The model, optimizer and metrics must be created inside the strategy scope so
# their variables are created under the distribution strategy.
with strategy.scope():

    # ImageNet-pretrained VGG16 backbone without its classifier head, sized from the config.
    base_model = tf.keras.applications.VGG16(
        include_top=False, weights='imagenet', input_shape=(HEIGHT, WIDTH, CHANNELS))

    # Freeze every backbone layer except the last three.
    for layer in base_model.layers[:-3]:
        layer.trainable = False

    # Global max pooling + a linear layer emitting N_CLASSES logits
    # (no softmax: the loss below uses from_logits=True).
    model = Sequential([base_model,
                        layers.GlobalMaxPool2D(),
                        layers.Dense(N_CLASSES)])

    # Use the configured, replica-scaled LEARNING_RATE. A hard-coded 0.01 left the
    # constant unused and is very large for fine-tuning pretrained layers.
    optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

    # reduction=NONE keeps per-example losses so loss_function can average them
    # over the *global* batch, not just this replica's share.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def loss_function(labels, preds):
        """Per-replica loss: sum of per-example losses divided by the global BATCH_SIZE."""
        per_example_loss = loss_object(labels, preds)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=BATCH_SIZE)

    # Running metrics; the training loop reads and resets them once per epoch.
    train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    train_acc = tf.keras.metrics.SparseCategoricalAccuracy('training_acc', dtype=tf.float32)
    test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
    test_acc = tf.keras.metrics.SparseCategoricalAccuracy('test_acc', dtype=tf.float32)

    #checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    


def forward_pass(ds_chunk):
    """Run the global ``model`` on one per-replica batch.

    Args:
        ds_chunk: an ``(inputs, labels)`` pair from the distributed dataset.

    Returns:
        ``(preds, labels)``: the model outputs and the labels, passed through.
    """
    inputs, labels = ds_chunk
    return model(inputs), labels

@tf.function
def multiple_dist_train_steps(dist_iter, steps):
    """Run up to `steps` distributed training steps from an existing iterator.

    Stops early once `dist_iter` is exhausted. Each step runs `_train_step` on
    every replica via `strategy.run`, applies gradients, and updates the
    train_loss / train_acc metrics. Returns nothing.

    Args:
        dist_iter: iterator over a distributed dataset, e.g. `iter(train_dist_ds)`.
        steps: maximum number of steps; looped with `tf.range`.
    """
    def _train_step(ds_chunk):
        with tf.GradientTape() as tape:
            preds, labels = forward_pass(ds_chunk)
            loss_val = loss_function(labels, preds)
        gradients = tape.gradient(loss_val, model.trainable_variables) 
        # experimental_aggregate_gradients=True (the default): the optimizer sums
        # gradients across replicas before applying them.
        optimizer.apply_gradients(zip(gradients, model.trainable_variables),
                                     experimental_aggregate_gradients=True)
        # loss_val is scaled by 1/global batch size; multiplying by the replica
        # count turns it back into this replica's mean loss (for full batches).
        train_loss.update_state(loss_val * strategy.num_replicas_in_sync)
        train_acc.update_state(labels, preds)
    
    for _ in tf.range(steps):
        # Optional-based fetch so the traced loop can stop when the iterator is
        # exhausted instead of raising OutOfRangeError.
        optional_data = dist_iter.get_next_as_optional()
        if not optional_data.has_value():
            break
        strategy.run(_train_step, args=(optional_data.get_value(),))
        # NOTE(review): per_replica_results is not defined here; this debug line
        # would need to capture strategy.run's return value first.
        # tf.print(strategy.experimental_local_results(per_replica_results))
        # tf.print(strategy.experimental_local_results(per_replica_results))

@tf.function
def dist_train_epoch(ds):
    """Train for one full pass over a distributed dataset.

    The dataset is iterated inside the tf.function, as in
    https://www.tensorflow.org/tutorials/distribute/custom_training#iterating_inside_a_tffunction

    Args:
        ds: a dataset from ``strategy.experimental_distribute_dataset``.

    Side effects: updates the model weights and the train_loss / train_acc metrics.
    """
    def replica_step(batch):
        # Forward pass and loss under the tape so gradients can be taken.
        with tf.GradientTape() as tape:
            logits, targets = forward_pass(batch)
            replica_loss = loss_function(targets, logits)
        weights = model.trainable_variables
        grads = tape.gradient(replica_loss, weights)
        optimizer.apply_gradients(zip(grads, weights),
                                  experimental_aggregate_gradients=True)

        # Undo the 1/global-batch scaling so the metric tracks the per-replica mean.
        train_loss.update_state(replica_loss * strategy.num_replicas_in_sync)
        train_acc.update_state(targets, logits)

    for batch in ds:
        strategy.run(replica_step, args=(batch,))

@tf.function
def multiple_dist_test_steps(dist_iter, steps):
    """Run up to `steps` distributed evaluation steps from an existing iterator.

    Stops early once `dist_iter` is exhausted. No weights are updated; each step
    only accumulates the test_loss / test_acc metrics. Returns nothing.

    Args:
        dist_iter: iterator over a distributed dataset, e.g. `iter(val_dist_ds)`.
        steps: maximum number of steps; looped with `tf.range`.
    """
    def _test_step(ds_chunk):
        preds, labels = forward_pass(ds_chunk)
        loss_val = loss_function(labels, preds)
        # Rescale the 1/global-batch loss by the replica count before averaging.
        test_loss.update_state(loss_val * strategy.num_replicas_in_sync)
        test_acc.update_state(labels, preds)

    for _ in tf.range(steps):
        # Optional-based fetch so the traced loop can stop on an exhausted iterator.
        optional_data = dist_iter.get_next_as_optional()
        if not optional_data.has_value():
            break
        strategy.run(_test_step, args=(optional_data.get_value(),))


@tf.function
def dist_test_epoch(ds):
    """Evaluate the model over one full pass of a distributed dataset.

    No weights change; only the test_loss / test_acc metrics accumulate.
    The dataset is iterated inside the tf.function, as in
    https://www.tensorflow.org/tutorials/distribute/custom_training#iterating_inside_a_tffunction
    """
    def replica_eval(batch):
        logits, targets = forward_pass(batch)
        # Same loss as training, rescaled by the replica count for the metric.
        scaled_loss = loss_function(targets, logits) * strategy.num_replicas_in_sync
        test_loss.update_state(scaled_loss)
        test_acc.update_state(targets, logits)

    for batch in ds:
        strategy.run(replica_eval, args=(batch,))



In [None]:
# Sanity check: pull one (optional) element from a fresh iterator over the
# distributed training set; as the cell's last expression, the Optional is displayed.
iter(train_dist_ds).get_next_as_optional()

In [None]:
from time import perf_counter, time  # `time` kept in case other cells rely on the name

# Time one full pass of the custom training loop over the distributed training set.
# This really trains for one epoch and updates train_loss / train_acc.
# To time a fixed number of steps instead:
#     multiple_dist_train_steps(iter(train_dist_ds), 100)

# perf_counter is monotonic, unlike time(), so it is the right clock for durations.
start = perf_counter()
dist_train_epoch(train_dist_ds)
end_ = int(round((perf_counter() - start) * 1000))  # elapsed milliseconds
print(f"one training epoch: {end_} ms")

In [None]:
# Main training loop: per epoch, one training pass and one validation pass.
# Epoch-level metric results are recorded for the plots in the next cell.
tracked_metrics = (train_loss, train_acc, test_loss, test_acc)

# Start from clean metrics (earlier cells may already have updated them).
for metric in tracked_metrics:
    metric.reset_states()

epochs = EPOCHS  # single source of truth: the config cell

train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(epochs):
    dist_train_epoch(train_dist_ds)
    dist_test_epoch(val_dist_ds)

    # Keep the scalar result tensors; the plotting cell converts them with .numpy().
    train_losses.append(train_loss.result())
    train_accuracies.append(train_acc.result())
    val_losses.append(test_loss.result())
    val_accuracies.append(test_acc.result())

    # "\r" rewrites the same line each epoch; flush so progress shows immediately.
    sys.stdout.write(f"\rEpoch {epoch+1}  train_loss: {train_loss.result():.3f}  train_acc: {train_acc.result():.3f}    val_loss: {test_loss.result():.3f}    val_acc: {test_acc.result():.3f}")
    sys.stdout.flush()

    # Metrics accumulate across calls, so reset them before the next epoch.
    for metric in tracked_metrics:
        metric.reset_states()

In [None]:
def round_fn(x):
    """Convert a list of scalar tensors into their values rounded to 3 decimals."""
    return list(map(lambda item: round(item.numpy(), 3), x))


import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# Derive the x axis from the recorded history so an interrupted run still plots.
x = list(range(1, len(train_losses) + 1))

ax1.plot(x, round_fn(train_losses), '--ro', label='train')
ax1.plot(x, round_fn(val_losses), '--bo', label='validation')
ax1.set(title='Loss per epoch', xlabel='epoch', ylabel='loss')
ax1.legend()

ax2.plot(x, round_fn(train_accuracies), '--r+', label='train')
ax2.plot(x, round_fn(val_accuracies), '--b+', label='validation')
ax2.set(title='Accuracy per epoch', xlabel='epoch', ylabel='accuracy')
ax2.legend()

fig.tight_layout()
plt.show()
