# Classification on Image Data with TensorFlow

In this example, we build a simple CNN to classify images of 10 types of flowers using TensorFlow (`tf.keras`). Then, we build a small application that lets you upload an image and label it in real time.

### Load data

Image data should be in a `zip` file and organized as one folder per label. More specifically, all images with the same label are placed in the same folder, and the folder name is the label name.
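As a quick check of this layout, the following sketch walks an extracted folder and counts the images under each label. The helper name `count_images_per_label` and the list of extensions are illustrative assumptions and are not part of the original notebook.

```python
from pathlib import Path

# Hypothetical helper (not part of the original notebook). The set of
# extensions is an assumption; adjust it to match your files.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}

def count_images_per_label(root):
    """Return {label: image_count}, treating each subfolder of `root` as one label."""
    counts = {}
    for folder in sorted(Path(root).iterdir()):
        if folder.is_dir():
            counts[folder.name] = sum(
                1 for f in folder.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts
```

Calling `count_images_per_label` on the extracted folder should return one entry per label; a missing or empty label usually means the folders were nested one level too deep inside the `zip` file.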

Please set `data_path` to the location of the `zip` file (for example, in your Google Drive). The curly brackets `{}` let us use a Python variable inside a terminal command (`!unzip`) in Google Colab.

In this example, we use the flowers.zip dataset which is originally from Kaggle: https://www.kaggle.com/datasets/jonathanflorez/extended-flowers-recognition


In [None]:
data_path = '/content/flowers.zip'
!unzip '{data_path}'

After unzipping, the images are stored in the `resized` folder, in 10 subfolders that represent the 10 classes.

### Process data

This part can be run as is.

In [None]:
import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
import tensorflow_datasets as tfds

In [None]:
img_height = 256
img_width = 256
batch_size = 32
data_dir = '/content/resized'

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

# number of classes
num_classes = len(train_ds.class_names)
# dictionary mapping class ids (as strings) to string labels
id2label = {str(i): label for i, label in enumerate(train_ds.class_names)}
# save the dictionary so the application can load it later
import pickle
with open('class_dict.dict', 'wb') as f:
  pickle.dump(id2label, f)

val_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

Found 8167 files belonging to 10 classes.
Using 6534 files for training.
Found 8167 files belonging to 10 classes.
Using 1633 files for validation.
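
The counts in this output follow from `validation_split=0.2` and `batch_size=32`. Below is a small check of the arithmetic, assuming the validation count is the split fraction of the total rounded down, which matches the numbers reported above.

```python
import math

total_files = 8167        # from the output above
validation_split = 0.2
batch_size = 32

num_val = int(validation_split * total_files)   # rounded down
num_train = total_files - num_val

# Batches per epoch; the last batch is smaller than batch_size.
train_batches = math.ceil(num_train / batch_size)
val_batches = math.ceil(num_val / batch_size)

print(num_train, num_val, train_batches, val_batches)  # 6534 1633 205 52
```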


### Modeling

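We can change a few hyperparameters to see if the performance improves. Save the model when you are happy with its performance.
- `num_epochs`: like in the previous module, this is the number of training iterations (epochs), where one epoch is one pass over the training data
- `num_cnns`: number of CNN blocks; each block has one Conv2D layer and one MaxPooling layer
- `learning_rate`: how fast the model updates at each training step
- `batch_size`: how many images are used in each batch (one training step). The datasets were already batched when they were loaded, so to change it, set it in the data processing cell and re-run that cell
- `weight_decay_rate`: how fast the learning rate drops while training (not set in the code below, which keeps the learning rate constant)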

We will also **augment** the training images by randomly cropping, zooming, flipping, and rotating them. The augmentation is disabled during inference.
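
To make the idea concrete, here is a NumPy-only sketch of two of these operations, a random crop and a random horizontal flip. It illustrates the concept only and is not how the Keras layers are implemented; `random_crop_and_flip` is a hypothetical name.

```python
import numpy as np

def random_crop_and_flip(image, crop_size, rng):
    """Randomly crop an (H, W, C) image to crop_size x crop_size,
    then flip it left-right with probability 0.5."""
    height, width = image.shape[:2]
    # Pick the top-left corner so the crop stays inside the image.
    top = rng.integers(0, height - crop_size + 1)
    left = rng.integers(0, width - crop_size + 1)
    patch = image[top:top + crop_size, left:left + crop_size]
    if rng.random() < 0.5:
        patch = patch[:, ::-1]  # reverse the width axis
    return patch

rng = np.random.default_rng(0)
image = np.zeros((256, 256, 3), dtype=np.uint8)  # stand-in for a loaded image
augmented = random_crop_and_flip(image, 224, rng)
print(augmented.shape)  # (224, 224, 3)
```

Each call produces a slightly different view of the same image, so the model sees more variety than the dataset actually contains.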

Finally, we use a Keras callback to save the best model during training, measured by validation accuracy.
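
The idea behind saving only the best model can be sketched without Keras. The accuracy values below are made up for illustration; in the notebook, Keras computes them after each epoch, and `best_epochs` is a hypothetical helper.

```python
def best_epochs(val_accuracies):
    """Return the 1-based epochs at which a 'save best only' rule would save."""
    best = float("-inf")
    saved_at = []
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best:  # strictly better than anything seen so far
            best = acc
            saved_at.append(epoch)
    return saved_at

# Made-up validation accuracies for five epochs.
print(best_epochs([0.41, 0.55, 0.53, 0.60, 0.58]))  # [1, 2, 4]
```

The file on disk is overwritten at each of these epochs, so after training it holds the weights from the last epoch in the list, not necessarily the final epoch.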

In [None]:
num_epochs = 10
num_cnns = 3
learning_rate = 1e-3
batch_size = 32

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

model = tf.keras.Sequential([
  tf.keras.layers.Rescaling(1./255),
  tf.keras.layers.RandomCrop(224, 224),
  tf.keras.layers.RandomFlip("horizontal"),
  tf.keras.layers.RandomRotation(factor=0.02),
  tf.keras.layers.RandomZoom(height_factor=0.2, width_factor=0.2),
])

for _ in range(num_cnns):
  model.add(tf.keras.layers.Conv2D(32, 3, activation='relu'))
  model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.Dense(num_classes))

model.compile(
  optimizer=tf.keras.optimizers.Adam(learning_rate),
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['accuracy'])

checkpoint_filepath = 'cnn_model.keras'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=num_epochs,
  callbacks=[model_checkpoint_callback]
)
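
To see how `num_cnns` affects the model, it helps to track the feature map size through the blocks. The sketch below assumes the Keras defaults used in the code above: `Conv2D` with a 3x3 kernel, stride 1, and no padding, `MaxPooling2D` with a 2x2 window, and a 224x224 input after the crop. `feature_map_size` is a hypothetical helper.

```python
def feature_map_size(input_size, num_cnns, kernel_size=3, pool_size=2):
    """Side length of the feature map after `num_cnns` Conv2D + MaxPooling2D blocks."""
    size = input_size
    for _ in range(num_cnns):
        size = size - kernel_size + 1   # 'valid' convolution, stride 1
        size = size // pool_size        # max pooling rounds down
    return size

side = feature_map_size(224, num_cnns=3)
flattened = side * side * 32            # 32 filters in the last Conv2D
dense_params = flattened * 256 + 256    # weights + biases of Dense(256)
print(side, flattened, dense_params)    # 26 21632 5538048
```

Under these assumptions, most of the parameters sit in the first `Dense` layer, so adding a block (which shrinks the feature map further) reduces the model size rather than increasing it.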

# Image Classification Application

Now we will build our application. This should be much easier than with tabular data, since we don't need large input forms. For applications with image inputs, we just need one button to upload an image and another to run the prediction.

First, load the trained model.

In [None]:
model_path = 'cnn_model.keras'

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io
from google.colab import files
import ipywidgets as widgets
from IPython.display import display, clear_output
import pickle

model = tf.keras.models.load_model(model_path)
with open('class_dict.dict','rb') as f:
  id2label = pickle.load(f)

In [None]:
#button to predict
button_predict = widgets.Button(description="Predict")
#upload button
uploader = widgets.FileUpload(multiple=False)
#output
output = widgets.Output()
#display everything
display(button_predict, uploader, output)

#prediction function to attach to the predict button
@output.capture()
def on_predict_clicked(b):
  output.clear_output()
  #nothing has been uploaded yet
  if not uploader.value:
    print('please upload an image first')
    return
  try:
    image = Image.open(io.BytesIO(list(uploader.value.values())[0]['content'])).convert("RGB")
    #resize to the 256x256 size used when loading the training data
    model_input = np.array(image.resize((256, 256)))
    predicted_class_id = np.argmax(model.predict(model_input[np.newaxis, ...], verbose=0))
    label = id2label[str(predicted_class_id)]
    plt.imshow(image)
    plt.title('this image is classified as ' + label, y=-0.2)
    plt.show()
  except Exception as e:
    #report the actual error instead of hiding it
    print('could not classify the image:', e)

button_predict.on_click(on_predict_clicked)

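
The application reports only the most likely class. Because the final `Dense` layer has no activation and the loss uses `from_logits=True`, `model.predict` returns raw scores (logits) rather than probabilities. If you want to show a confidence next to the label, a softmax converts the scores. Below is a NumPy sketch with made-up logits for three classes; the `softmax` helper is an illustration and not part of the original notebook.

```python
import numpy as np

def softmax(logits):
    """Convert logits to probabilities along the last axis."""
    # Subtract the row maximum for numerical stability; the result is unchanged.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

logits = np.array([[2.0, 0.5, -1.0]])  # made-up scores for one image
probs = softmax(logits)
predicted_class_id = int(np.argmax(probs))
print(predicted_class_id, probs[0, predicted_class_id])
```

The predicted class is the same with or without the softmax, since it does not change the ordering of the scores.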