In [None]:
import os

import numpy as np
import pandas as pd
import tensorflow as tf
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import layers

# Resolve the GCS path of the competition data.  When the kaggle_datasets
# helper module is importable (presumably only inside a Kaggle kernel), ask it
# for the bucket path; otherwise fall back to a hardcoded bucket.
try:
    from kaggle_datasets import KaggleDatasets
    dataset_gcs = KaggleDatasets().get_gcs_path('cassava-leaf-disease-classification')
    print('got GCS path via KaggleDatasets .get_gcs_path method')
except ModuleNotFoundError:
    # hardcode path while testing locally
    # NOTE(review): this kds-* bucket path looks like it was copied from a
    # Kaggle session and may stop resolving later; confirm or make configurable
    dataset_gcs = 'gs://kds-e118bcdb309cf88b7f9e4a96ee84997123a5781b886180ffc13d3fc9'

In [None]:
# Display the TensorFlow version; several APIs used below live in
# `experimental` namespaces whose availability may differ between releases.
tf.__version__

In [None]:
# Training hyper-parameters.  Images are square, so img_size holds two equal
# side lengths.
params = dict(
    batch_size=128,
    img_size=[512, 512],
    epochs=400,
)

In [None]:
# Pick a distribution strategy: a TPU strategy when a TPU cluster can be
# resolved, otherwise TensorFlow's default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    # the resolver raised ValueError, so treat this environment as TPU-less
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    # the TPU system must be connected and initialized before building the strategy
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
train_df = pd.read_csv(f'{dataset_gcs}/train.csv')
# rows per class label
train_df.groupby(by='label').count()


In [None]:
def decode_image_train(tfrec):
    '''
    Parse one serialized training/validation example and decode its image.

    The example must contain an encoded JPEG under 'image', a string
    'image_name' and an int64 'target'.  All three are parsed, but only the
    decoded image and the target are returned.

    args:
        tfrec: scalar string tensor, a single serialized tfrecord example

    returns:
        decoded_image: uint8 image tensor with 3 channels; height and width
            are not given a static shape here
        target: int64 tensor, the class label
    '''
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'image_name': tf.io.FixedLenFeature([], tf.string),
        'target': tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(tfrec, schema)
    image = tf.io.decode_jpeg(example['image'], channels=3)
    return image, example['target']


In [None]:
def normalize_image(decoded_image, label):
    '''
    Rescale image values from [0, 255] to [-1, 1], passing the label through.

    Meant to be used as a tf.data map function over (image, label) pairs.

    args:
        decoded_image: tensor, an image with values from 0 to 255
        label: tensor, target label

    returns:
        image_tensor: float32 tensor, same shape as the input, values from -1 to 1
        label: the input label, unchanged
    '''
    # (x - 127.5) / 127.5 maps 0 -> -1 and 255 -> 1
    centered = tf.cast(decoded_image, tf.float32) - 127.5
    return centered / 127.5, label

In [None]:
def random_flip(image, label):
    '''
    Augment an image with independent random left/right and up/down flips.

    args:
        image: tensor, an image
        label: tensor, target label

    returns:
        image: tensor, the input image, possibly mirrored horizontally
            and/or vertically
        label: tensor, the input label, unchanged
    '''
    # left/right flip is applied first, then up/down, each with its own coin toss
    flipped = tf.image.random_flip_up_down(tf.image.random_flip_left_right(image))
    return flipped, label


In [None]:
def one_hot(image, label):
    '''
    Replace an integer class label with its one-hot encoding (5 classes).

    args:
        image: tensor, an image (returned untouched)
        label: tensor, integer target label

    returns:
        image: tensor, same as input
        label: tensor, one-hot vector of depth 5
    '''
    return image, tf.one_hot(label, depth=5)

In [None]:
def get_train_ds(tfrecords, batch_size, augment=True):
    '''
    Build a batched, infinitely repeating tf.data pipeline from tfrecord files.

    Serialized records are read in parallel, cached, shuffled, then decoded
    (decode_image_train), scaled to [-1, 1] (normalize_image), optionally
    randomly flipped (random_flip) and given one-hot labels (one_hot).

    Because the dataset repeats forever, callers must pass steps_per_epoch /
    validation_steps to model.fit.

    args:
        tfrecords: list of str, tfrecord file paths
        batch_size: int, number of records per batch; a final partial batch
            is dropped so every batch has the same size
        augment: bool, apply random flips (default True, the previous
            behavior); pass False for validation/evaluation data
    returns:
        ds: tf.data.Dataset yielding (images, one_hot_labels) batches
    '''
    autotune = tf.data.experimental.AUTOTUNE

    # the file list is passed directly; wrapping it in another list was redundant
    ds = tf.data.TFRecordDataset(filenames=tfrecords, num_parallel_reads=autotune)
    # cache the raw serialized records rather than decoded images
    # NOTE: per the original author, cache() may need removing when not on TPUs
    ds = ds.cache()
    # shuffle the compact serialized records: a 512-element buffer of decoded
    # float32 images would be far larger (~3 MB per image at 512x512x3).
    # shuffling before repeat() keeps each epoch a full pass over the data.
    ds = ds.shuffle(512)
    ds = ds.repeat()

    ds = ds.map(decode_image_train, num_parallel_calls=autotune)
    ds = ds.map(normalize_image, num_parallel_calls=autotune)
    if augment:
        ds = ds.map(random_flip, num_parallel_calls=autotune)
    ds = ds.map(one_hot, num_parallel_calls=autotune)

    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(autotune)

In [None]:
def get_ds_size(files):
    '''
    Count the records in a set of tfrecord files from their file names.

    Each file name must end in ``-<record count>.tfrec``; for example
    ``train09-2071.tfrec`` holds 2071 records.  Only the base name is
    inspected, so the directory the files live in does not matter.

    args:
        files: iterable of str, each item a path to a tfrecord file

    returns:
        size: int, total number of records across all files (0 for no files)

    raises:
        ValueError: if a file name does not end in ``-<digits>.tfrec``
    '''
    size = 0
    for file in files:
        name = os.path.basename(file)
        stem = name[:-len('.tfrec')] if name.endswith('.tfrec') else name
        # the count is whatever follows the last '-' in the stem
        _, sep, count = stem.rpartition('-')
        if not sep or not count.isdigit():
            raise ValueError(
                'cannot read a record count from tfrecord file name: {}'.format(file))
        size += int(count)
    return size


In [None]:
# split the tfrecord *files* (not individual images) 90/10 into train/validation
train_files, valid_files = train_test_split(
    tf.io.gfile.glob(f'{dataset_gcs}/train_tfrecords/*.tfrec'),
    test_size=.1,
    random_state=1,
)
# build one input pipeline per split
# NOTE(review): validation goes through the same get_train_ds pipeline as
# training, so it is shuffled, repeated and randomly flipped too
train_ds, valid_ds = [get_train_ds(files, params['batch_size'])
                      for files in (train_files, valid_files)]

In [None]:
train_size = get_ds_size(train_files)
valid_size = get_ds_size(valid_files)
print(f'the dataset consists of: {train_size} training images, and {valid_size} validation images')

In [None]:
# whole batches per pass over each split (the pipelines drop partial batches
# and repeat forever, so model.fit needs explicit step counts)
epoch_steps, valid_steps = (size // params['batch_size'] for size in (train_size, valid_size))

In [None]:
# number of images per class, ordered by label (groupby sorts keys by default)
targets = train_df.groupby('label')['image_id'].count().to_list()
# TODO: consider deriving class weights from these counts

In [None]:
# display the per-class image counts
targets

In [None]:
def create_model(input_shape=[*params['img_size'], 3], num_classes=5,
                 hidden_units=8, hidden_activation=None):
    '''
    Build and compile an Xception-based image classifier.

    Inputs are resized to 299x299 and passed through a Keras Xception backbone
    without its classification top.  The result is then globally average
    pooled, sent through one Dense layer, and finished with a softmax head.

    args:
        input_shape: list/tuple, (height, width, channels) of the input images;
            defaults to params['img_size'] + [3] as evaluated when this cell ran
        num_classes: int, number of output classes (default 5)
        hidden_units: int, width of the Dense layer between pooling and head
            (default 8)
        hidden_activation: activation of that Dense layer.  The default None
            keeps the original linear layer.  A linear layer followed directly
            by another Dense layer adds only a low-rank bottleneck, not extra
            non-linearity, so consider e.g. 'relu'.
    returns:
        model: compiled keras.Model (Adam optimizer, categorical cross-entropy
            with label smoothing 0.01, categorical accuracy named 'accuracy')
    '''
    input_tensor = layers.Input(shape=input_shape, name='images_input')
    resized = layers.experimental.preprocessing.Resizing(299, 299)(input_tensor)

    # `classes` only matters when include_top=True, so it is not passed here
    backbone = tf.keras.applications.Xception(include_top=False, input_shape=[299, 299, 3])
    pooled = layers.GlobalAveragePooling2D()(backbone(resized))
    hidden = layers.Dense(hidden_units, activation=hidden_activation)(pooled)
    output = layers.Dense(num_classes, activation='softmax')(hidden)
    model = keras.Model(inputs=input_tensor, outputs=output)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.01),
        # named 'accuracy' so callbacks can monitor 'val_accuracy'
        metrics=[keras.metrics.CategoricalAccuracy(name='accuracy')],
    )
    return model
    

In [None]:
# build the model inside the strategy scope so its variables are created under
# the distribution strategy chosen earlier (TPU or default)
with strategy.scope():
    model = create_model()

In [None]:
# print the layer-by-layer architecture with parameter counts
model.summary()

In [None]:
# stop when validation accuracy has not improved for 25 epochs, then roll the
# model back to the best weights seen
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    patience=25,
    restore_best_weights=True,
)

def lr_schedule_fn(epoch, lr):
    '''
    Learning-rate schedule for keras.callbacks.LearningRateScheduler.

    - epochs 0-7: fixed warm-up rate of 1e-6
    - epoch 8: jump to 1e-3
    - even epochs from 10 through 48: decay the current rate by 0.75
    - every other epoch: keep the current rate

    args:
        epoch: int, zero-based epoch index
        lr: float, learning rate currently in use

    returns:
        float, learning rate to use for this epoch
    '''
    if epoch < 8:
        return 1e-6
    if epoch == 8:
        return 0.001
    decay_now = epoch < 49 and epoch % 2 == 0
    return lr * 0.75 if decay_now else lr
# apply lr_schedule_fn at the start of every epoch
lr_schedule = tf.keras.callbacks.LearningRateScheduler(schedule=lr_schedule_fn)

In [None]:
# Train.  Both datasets repeat forever, so the per-epoch step counts are
# passed explicitly.  TODO: consider class_weight once class weights exist.
history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=params['epochs'],
    steps_per_epoch=epoch_steps,
    validation_steps=valid_steps,
    callbacks=[early_stopping, lr_schedule],
    verbose=1,
)