# Tutorial B: Classify Flower Images with a CNN

Adapted from: https://medium.com/@nutanbhogendrasharma/tensorflow-image-classification-with-tf-flowers-dataset-e36205deb8fc

In this tutorial, we train a convolutional neural network (CNN) to classify photos of five types of flowers. The `tf_flowers` dataset is provided by TensorFlow Datasets (TFDS).

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import random
import numpy as np

import tensorflow_datasets as tfds

## Download The Dataset

In [None]:
(training_set, validation_set), dataset_info = tfds.load(
    'tf_flowers',
    split=['train[:70%]', 'train[70%:]'],
    with_info=True,
    as_supervised=True,
)
dataset_info

In [None]:
num_classes = dataset_info.features['label'].num_classes
num_classes
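
The `split` argument slices the single TFDS `train` split by percentage. As a rough sketch, assuming the 3,670-image `train` split listed in the TFDS catalog for `tf_flowers` (verify with `dataset_info.splits['train'].num_examples`), the 70/30 slicing works out as follows in plain Python:

```python
# Assumption: tf_flowers provides a single 'train' split of 3,670 images.
TOTAL_IMAGES = 3670

def split_sizes(total, train_percent):
    """Sizes of 'train[:p%]' and 'train[p%:]' for an integer percentage p."""
    n_train = round(total * train_percent / 100)
    return n_train, total - n_train

n_train, n_val = split_sizes(TOTAL_IMAGES, 70)
print(n_train, n_val)  # 2569 1101
```

TFDS applies its own rounding rule at slice boundaries, so for totals that do not divide evenly, the `round` used here is only an approximation.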

## Visualize The Data

Below, we plot three example images from the training set together with their labels.

In [None]:
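fig, axs = plt.subplots(1, 3, figsize=(12, 4))
fig.suptitle("Example Images")

# ClassLabel feature used to map integer labels back to class names
label_feature = dataset_info.features['label']

for i, (image, label) in enumerate(training_set.take(3)):
    label_id = int(label.numpy())
    axs[i].set_title('{} ({})'.format(label_feature.int2str(label_id), label_id))
    axs[i].imshow(image.numpy())  # RGB uint8 image, so no colormap is needed
    axs[i].axis('off')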

## Format The Images

The visualization above shows that the images come in different sizes. We therefore resize every image to a fixed resolution of 224×224 pixels and scale the pixel values from the range [0, 255] to [0.0, 1.0]. Finally, we shuffle the training set and batch both the training and the validation sets for the training process.

In [None]:
IMAGE_RES = 224

def format_image(image, label):
    image = tf.image.resize(image, (IMAGE_RES, IMAGE_RES))/255.0
    return image, label

BATCH_SIZE = 32

train_batches = training_set.shuffle(len(training_set)//4).map(format_image).batch(BATCH_SIZE).prefetch(1)
validation_batches = validation_set.map(format_image).batch(BATCH_SIZE).prefetch(1)

print(train_batches)
print(validation_batches)
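
To see what the pipeline above produces, here is a small NumPy sketch of two of its effects: the pixel scaling done in `format_image` and the number of batches `.batch(BATCH_SIZE)` yields per epoch. The split sizes below are an assumption (the 3,670-image `tf_flowers` train split divided 70/30), not values read from the dataset.

```python
import math
import numpy as np

BATCH_SIZE = 32

def batches_per_epoch(n_examples, batch_size):
    """Batches yielded by .batch() when the final, partial batch is kept."""
    return math.ceil(n_examples / batch_size)

# Assumed split sizes (70/30 split of 3,670 images)
n_train, n_val = 2569, 1101
print(batches_per_epoch(n_train, BATCH_SIZE))  # 81 (the last batch holds 9 images)
print(batches_per_epoch(n_val, BATCH_SIZE))    # 35
print(n_train // 4)                            # 642, the shuffle buffer size

# format_image divides by 255.0, mapping uint8 pixels in [0, 255] to [0.0, 1.0]
pixels = np.array([0, 51, 255], dtype=np.uint8)
scaled = pixels / 255.0
print(scaled.min(), scaled.max())  # 0.0 1.0
```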

## Create The CNN Model

In [None]:
model = tf.keras.models.Sequential([
  # Three convolution + max-pooling stages extract increasingly abstract features
  tf.keras.layers.Conv2D(112, (3,3), activation='relu', input_shape=(IMAGE_RES, IMAGE_RES, 3)),
  tf.keras.layers.MaxPooling2D(2, 2),
  tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
  tf.keras.layers.MaxPooling2D(2, 2),
  tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
  tf.keras.layers.MaxPooling2D(2, 2),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  # One output probability per flower class
  tf.keras.layers.Dense(num_classes, activation='softmax')
])
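
The layer sizes of this model can be checked by hand. A plain-Python sketch, assuming the Keras defaults used above (`padding='valid'` and stride 1 for `Conv2D`, non-overlapping 2×2 windows for `MaxPooling2D`), a 224×224×3 input and 5 output classes, that tracks the feature-map shape and counts trainable parameters:

```python
def conv2d_valid(size, kernel=3):
    """Spatial size after a stride-1 Conv2D with padding='valid'."""
    return size - kernel + 1

def maxpool(size, pool=2):
    """Spatial size after MaxPooling2D(2, 2) with padding='valid' (floor division)."""
    return size // pool

def conv_params(in_channels, filters, kernel=3):
    """kernel*kernel*in_channels weights per filter, plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters

def dense_params(inputs, units):
    """One weight per input/unit pair, plus one bias per unit."""
    return inputs * units + units

size, channels, total = 224, 3, 0
for filters in (112, 64, 32):
    total += conv_params(channels, filters)
    size = maxpool(conv2d_valid(size))
    channels = filters
    print(f"after Conv2D({filters}) + pooling: {size}x{size}x{channels}")

flat = size * size * channels
total += dense_params(flat, 128) + dense_params(128, 5)
print("flattened features:", flat)     # 21632
print("trainable parameters:", total)  # 2855845
```

`model.summary()` should report the same shapes and total. Roughly 97% of the parameters sit in the first `Dense` layer, because it connects all 21,632 flattened features to 128 units.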

In [None]:
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

history = model.fit(
    train_batches, 
    validation_data=validation_batches,
    epochs=5)
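
Because `as_supervised=True` yields integer labels (0 to 4) rather than one-hot vectors, the model is compiled with `sparse_categorical_crossentropy`. A minimal pure-Python sketch of what that loss computes for a single example, ignoring the small epsilon clipping Keras applies internally and the averaging over each batch; the probability vector is a made-up softmax output for illustration:

```python
import math

def sparse_categorical_crossentropy(probs, label):
    """Loss for one example: -log of the probability assigned to the true class."""
    return -math.log(probs[label])

def one_hot(label, num_classes):
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

def categorical_crossentropy(probs, target):
    """Same loss, but taking a one-hot target vector instead of an integer label."""
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0)

probs = [0.1, 0.6, 0.1, 0.1, 0.1]  # hypothetical softmax output over 5 classes
label = 1                          # integer label, as produced by as_supervised=True
print(round(sparse_categorical_crossentropy(probs, label), 4))  # 0.5108
print(round(categorical_crossentropy(probs, one_hot(label, 5)), 4))  # 0.5108
```

Both forms give the same value; the sparse variant simply skips building the one-hot vector.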

In [None]:
def plot_loss(history):
  plt.plot(history.history['loss'], label='loss')
  plt.plot(history.history['val_loss'], label='val_loss')
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.legend()
  plt.grid(True)

plot_loss(history)
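
When reading the loss plot, validation loss rising while training loss keeps falling is the usual sign of overfitting. A small, hypothetical helper (not part of Keras) that reads a dictionary shaped like `history.history`, where each metric name maps to one value per epoch:

```python
def best_epoch(history_dict, key="val_loss"):
    """Return (1-based epoch, value) for the lowest value recorded under `key`."""
    values = history_dict[key]
    idx = min(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]

def is_overfitting(history_dict):
    """Heuristic: training loss fell but validation loss rose in the last epoch."""
    loss, val_loss = history_dict["loss"], history_dict["val_loss"]
    return len(loss) >= 2 and loss[-1] < loss[-2] and val_loss[-1] > val_loss[-2]
```

In the notebook it would be called as `best_epoch(history.history)` and `is_overfitting(history.history)`.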

## Additional: View GPU Usage

In [None]:
# get_memory_info reports memory allocated by TensorFlow on the device, in bytes
if tf.config.list_physical_devices('GPU'):
  usage = tf.config.experimental.get_memory_info('GPU:0')
  print('Current: {:,} bytes'.format(usage['current']))
  print('Peak: {:,} bytes'.format(usage['peak']))
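
Raw byte counts for a model of this size are hard to read. A small, hypothetical formatting helper (not part of TensorFlow) that converts the `current` and `peak` values into binary units:

```python
def format_bytes(n_bytes):
    """Render a byte count using binary units (1 KiB = 1024 bytes)."""
    if n_bytes < 1024:
        return f"{n_bytes} bytes"
    value = float(n_bytes)
    for unit in ("KiB", "MiB", "GiB", "TiB"):
        value /= 1024
        if value < 1024 or unit == "TiB":
            return f"{value:,.1f} {unit}"
```

For example, `print('Peak:', format_bytes(usage['peak']))` inside the `if` block above.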