In [1]:
import math, re, os
import numpy as np
import tensorflow as tf
import warnings 
warnings.filterwarnings("ignore")
from kaggle_datasets import KaggleDatasets

# Report the TensorFlow build this notebook is running against.
print(f"Tensorflow version {tf.__version__}")

Tensorflow version 2.11.0


In [2]:
# Try to locate a TPU. The resolver's ValueError is treated as "no TPU
# available", in which case the notebook falls back to the default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    # Connect to the TPU cluster and initialise it before building the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # tf.distribute.experimental.TPUStrategy is a deprecated endpoint;
    # tf.distribute.TPUStrategy is its replacement.
    strategy = tf.distribute.TPUStrategy(tpu)
else:
    # Default strategy (whatever device TensorFlow picks without a TPU).
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)

REPLICAS:  1


In [3]:
class DataPipeline:
    """Builds tf.data input pipelines from the 'tpu-getting-started' TFRecord files.

    The constructor resolves the dataset location, lists the train / val / test
    TFRecord files for the requested image size and stores the class names.
    The ``get_*_dataset`` methods return batched, prefetched ``tf.data`` datasets.
    """

    def __init__(self, image_size=512, batch_size=16):
        """Resolve file lists and configuration.

        Args:
            image_size: side length of the square images; one of 192, 224, 331, 512.
            batch_size: number of examples per batch in every returned dataset.

        Raises:
            AssertionError: if ``image_size`` is not one of the supported sizes.
        """
        # Validate first so a bad size fails before the dataset path is resolved.
        assert image_size in (192,224,331,512), \
            f"image_size must be one of 192, 224, 331, 512; got {image_size!r}"
        self.GCS_DS_PATH = KaggleDatasets().get_gcs_path('tpu-getting-started')
        # [height, width] of every image in the selected TFRecord set.
        self.IMAGE_SIZE = [image_size, image_size]
        self.BATCH_SIZE = batch_size
        self.GCS_PATH = self.GCS_DS_PATH + f'/tfrecords-jpeg-{image_size}x{image_size}'
        # Let tf.data choose parallelism / prefetch buffer sizes.
        self.AUTO = tf.data.experimental.AUTOTUNE
        self.TRAINING_FILENAMES = tf.io.gfile.glob(self.GCS_PATH + '/train/*.tfrec')
        self.VALIDATION_FILENAMES = tf.io.gfile.glob(self.GCS_PATH + '/val/*.tfrec')
        self.TEST_FILENAMES = tf.io.gfile.glob(self.GCS_PATH + '/test/*.tfrec')

        # Class names; the list position is the integer class id (see row comments).
        self.CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
                        'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
                        'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
                        'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
                        'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
                        'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
                        'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
                        'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
                        'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
                        'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
                        'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103

    def decode_image(self, image_data):
        """Decode JPEG bytes into a float32 image scaled to the range [0, 1].

        Args:
            image_data: scalar string tensor holding one JPEG-encoded image.

        Returns:
            A float32 tensor reshaped to ``[height, width, 3]`` using ``IMAGE_SIZE``.
        """
        image = tf.image.decode_jpeg(image_data, channels=3)
        image = tf.cast(image, tf.float32) / 255.0
        # IMAGE_SIZE is a two-element list, so it has to be unpacked here:
        # [self.IMAGE_SIZE, 3] would be a nested list, not a flat
        # [height, width, 3] shape.
        image = tf.reshape(image, [*self.IMAGE_SIZE, 3])
        return image

    def read_labeled_tfrecord(self, example):
        """Parse one serialized labeled example into an ``(image, label)`` pair.

        Args:
            example: scalar string tensor with a serialized example containing
                an ``image`` (JPEG bytes) and a ``class`` (int64) feature.

        Returns:
            Tuple of the decoded image and the class id cast to int32.
        """
        LABELED_TFREC_FORMAT = {
            "image": tf.io.FixedLenFeature([], tf.string),
            "class": tf.io.FixedLenFeature([], tf.int64),
        }
        example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
        image = self.decode_image(example['image'])
        label = tf.cast(example['class'], tf.int32)
        return image, label

    def read_unlabeled_tfrecord(self, example):
        """Parse one serialized unlabeled (test) example into an ``(image, id)`` pair.

        Args:
            example: scalar string tensor with a serialized example containing
                an ``image`` (JPEG bytes) and an ``id`` (string) feature.

        Returns:
            Tuple of the decoded image and the id string tensor.
        """
        UNLABELED_TFREC_FORMAT = {
            "image": tf.io.FixedLenFeature([], tf.string),
            "id": tf.io.FixedLenFeature([], tf.string),
        }
        example = tf.io.parse_single_example(example, UNLABELED_TFREC_FORMAT)
        image = self.decode_image(example['image'])
        idnum = example['id']
        return image, idnum

    def load_dataset(self, filenames, labeled=True, ordered=False):
        """Read TFRecord files into a dataset of parsed examples.

        Several files are read in parallel. Unless ``ordered`` is True the
        deterministic-order option is switched off, which the original author
        notes is faster and acceptable because the data is shuffled later.

        Args:
            filenames: list of TFRecord file paths.
            labeled: parse with ``read_labeled_tfrecord`` when True, otherwise
                with ``read_unlabeled_tfrecord``.
            ordered: keep the default (deterministic) element order when True.

        Returns:
            Dataset of ``(image, label)`` pairs if ``labeled`` else ``(image, id)`` pairs.
        """
        ignore_order = tf.data.Options()
        if not ordered:
            ignore_order.experimental_deterministic = False # disabling order

        dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=self.AUTO)
        dataset = dataset.with_options(ignore_order)
        parse_fn = self.read_labeled_tfrecord if labeled else self.read_unlabeled_tfrecord
        dataset = dataset.map(parse_fn, num_parallel_calls=self.AUTO)
        return dataset

    def data_augment(self, image, label):
        """Randomly augment one training image; the label is passed through.

        Applies a random left/right flip, a random up/down flip and a random
        saturation change with a factor drawn between 0 and 2.
        """
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        image = tf.image.random_saturation(image, 0, 2)
        return image, label

    def get_training_dataset(self, ordered=False):
        """Return the augmented, endlessly repeating, shuffled, batched training set.

        Because the dataset repeats without end, callers must bound each epoch
        themselves (e.g. with ``steps_per_epoch`` in ``model.fit``).
        """
        dataset = self.load_dataset(self.TRAINING_FILENAMES, labeled=True, ordered=ordered)
        dataset = dataset.map(self.data_augment, num_parallel_calls=self.AUTO)
        # A single unbounded repeat() is enough; the former repeat(10) in front
        # of it was redundant because the stream is infinite either way.
        dataset = dataset.repeat() # the training dataset must repeat for several epochs
        dataset = dataset.shuffle(2048)
        dataset = dataset.batch(self.BATCH_SIZE)
        dataset = dataset.prefetch(self.AUTO) # get next batch while training
        return dataset

    def get_validation_dataset(self, ordered=False):
        """Return the batched validation set, cached after batching and prefetched."""
        dataset = self.load_dataset(self.VALIDATION_FILENAMES, labeled=True, ordered=ordered)
        dataset = dataset.batch(self.BATCH_SIZE)
        dataset = dataset.cache()
        dataset = dataset.prefetch(self.AUTO)
        return dataset

    def get_test_dataset(self, ordered=False):
        """Return the batched, prefetched test set of ``(image, id)`` pairs."""
        dataset = self.load_dataset(self.TEST_FILENAMES, labeled=False, ordered=ordered)
        dataset = dataset.batch(self.BATCH_SIZE)
        dataset = dataset.prefetch(self.AUTO)
        return dataset

In [4]:
from tensorflow.keras.layers import Dense, Activation, Conv2D, MaxPool2D, Dropout, Flatten
# Seeded Glorot-uniform initializer used for every convolution kernel below.
# NOTE(review): one initializer instance is shared by all four Conv2D layers;
# confirm that sharing a single seeded instance is intended.
gu_seed=tf.keras.initializers.GlorotUniform(seed=1)
data = DataPipeline()

# Build the model inside the strategy scope so its variables are created
# under the selected distribution strategy.
with strategy.scope():
    model0 = tf.keras.Sequential()
    # The input shape follows the pipeline's image size instead of a hard-coded
    # (512, 512, 3), so the model stays in sync with DataPipeline(image_size=...).
    model0.add(Conv2D(32, kernel_size=(5,5), kernel_initializer=gu_seed, padding='same', activation='relu', input_shape=(*data.IMAGE_SIZE, 3)))
    model0.add(MaxPool2D(pool_size=(3,3)))
    model0.add(Dropout(0.25))
    
    model0.add(Conv2D(64, kernel_size=(5,5), kernel_initializer=gu_seed, padding='same', activation='relu'))
    model0.add(MaxPool2D(pool_size=(3,3)))
    model0.add(Dropout(0.25))
    
    model0.add(Conv2D(64, kernel_size=(5,5), kernel_initializer=gu_seed, padding='same', activation='relu'))
    model0.add(MaxPool2D(pool_size=(3,3)))
    model0.add(Dropout(0.25))
    
    model0.add(Conv2D(64, kernel_size=(5,5), kernel_initializer=gu_seed, padding='same', activation='relu'))
    model0.add(MaxPool2D(pool_size=(3,3)))
    model0.add(Dropout(0.25))
    
    # Classifier head: one softmax output per class name in the pipeline.
    model0.add(Flatten())
    model0.add(Dense(len(data.CLASSES), activation='softmax'))

In [5]:
# Labels are integer class ids, hence the sparse categorical loss and metric.
model0.compile(
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy'],
    optimizer='adam',
)

# Print the layer-by-layer architecture with output shapes and parameter counts.
model0.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 512, 512, 32)      2432      
                                                                 
 max_pooling2d (MaxPooling2D  (None, 170, 170, 32)     0         
 )                                                               
                                                                 
 dropout (Dropout)           (None, 170, 170, 32)      0         
                                                                 
 conv2d_1 (Conv2D)           (None, 170, 170, 64)      51264     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 56, 56, 64)       0         
 2D)                                                             
                                                                 
 dropout_1 (Dropout)         (None, 56, 56, 64)        0

In [6]:
def count_data_items(filenames):
    """Return the total number of records encoded in TFRecord file names.

    Each file name is expected to carry its record count as the digits between
    a '-' and the following '.', e.g. ``flowers00-230.tfrec`` counts as 230.

    Args:
        filenames: iterable of file name / path strings.

    Returns:
        The sum of the per-file counts as a NumPy integer (0 when there are
        no file names).

    Raises:
        ValueError: if a file name contains no ``-<digits>.`` count.
    """
    # Compile once; '+' (rather than '*') guarantees the captured group has digits.
    count_pattern = re.compile(r"-([0-9]+)\.")
    counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(f"cannot read item count from file name: {filename!r}")
        counts.append(int(match.group(1)))
    # An explicit integer dtype keeps the result integral even for an empty
    # list; np.sum([]) would otherwise return the float 0.0.
    return np.sum(counts, dtype=np.int64)

# Dataset sizes for the train / validation / test splits, read from the
# counts embedded in the TFRecord file names.
NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES = (
    count_data_items(split_files)
    for split_files in (
        data.TRAINING_FILENAMES,
        data.VALIDATION_FILENAMES,
        data.TEST_FILENAMES,
    )
)

In [7]:
# Number of training epochs passed to model.fit in the (disabled) run below.
EPOCHS = 10

# NOTE: the training run below is currently commented out, so this cell only
# defines EPOCHS. Uncomment the block to build the datasets and train model0.
# steps_per_epoch is supplied because the training dataset repeats endlessly.

# ds_train = data.get_training_dataset()
# ds_valid = data.get_validation_dataset()
# ds_test = data.get_test_dataset()

# STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // data.BATCH_SIZE

# history = model0.fit(
#     ds_train,
#     validation_data=ds_valid,
#     epochs=EPOCHS,
#     steps_per_epoch=STEPS_PER_EPOCH
# )