<a href="https://colab.research.google.com/github/lorihe/Springboard-Capstone3---Transfer_Learning/blob/main/Game_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, ImageDataGenerator
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from tensorflow.keras import callbacks, optimizers
from tensorflow.keras.layers import Dense, Conv2D, GlobalAvgPool2D, Input
import numpy as np
from google.colab import drive

In [None]:
# Mount Google Drive at /content/drive so files stored there are reachable from the Colab runtime.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd drive/MyDrive/Game_Classification

/content/drive/MyDrive/Game_Classification


In [None]:
def img_Data(dir_path, target_size, batch, class_lst, preprocessing=None, shuffle=True):
  """
  Build a directory-backed image batch iterator that yields sparse integer labels.

  No augmentation is configured; the only per-image transform is the optional
  `preprocessing` function.

  Args:
      dir_path (str): Path to the directory containing one subdirectory of images per class.
      target_size (tuple): (height, width) that every loaded image is resized to.
      batch (int): Number of images per yielded batch.
      class_lst (list): Class subdirectory names under `dir_path`. The order of this
          list determines the integer label assigned to each class.
      preprocessing (callable, optional): Function applied to each image after it is
          resized. It should take one image (rank-3 Numpy array) and return a Numpy
          array of the same shape. Defaults to None (no preprocessing).
      shuffle (bool, optional): Whether the iterator shuffles the images. Defaults to
          True. Pass False when the batch order must match the file order
          (e.g. when comparing predictions against `iterator.classes`).

  Returns:
      DirectoryIterator: Iterator yielding (images, labels) batches, with labels as
      integer class indices (class_mode='sparse').
  """
  # preprocessing_function=None is ImageDataGenerator's own default, so a single
  # constructor call covers both the "with" and "without preprocessing" cases.
  gen_object = ImageDataGenerator(preprocessing_function=preprocessing)

  return gen_object.flow_from_directory(
      dir_path,
      target_size=target_size,
      batch_size=batch,
      class_mode='sparse',
      classes=class_lst,
      shuffle=shuffle,
  )

In [None]:
# Training batches: 224x224 images, 500 per batch, labels 0 = rugby, 1 = soccer,
# run through MobileNetV2's preprocess_input.
train_data_gen = img_Data(
    dir_path='clean_data/train',
    target_size=(224, 224),
    batch=500,
    class_lst=['rugby', 'soccer'],
    preprocessing=preprocess_input,
)

Found 2448 images belonging to 2 classes.


In [None]:
# Validation batches, built from the held-out 'clean_data/test' folder with the
# same image size, batch size, class order and preprocessing as the training data.
valid_data_gen = img_Data(
    dir_path='clean_data/test',
    target_size=(224, 224),
    batch=500,
    class_lst=['rugby', 'soccer'],
    preprocessing=preprocess_input,
)

Found 610 images belonging to 2 classes.


In [None]:
# Use ImageNet pre-trained weights to build the base model, without its
# classification head (include_top=False), so it acts as a feature extractor.
# `classes` and `classifier_activation` only apply when include_top=True, so they
# are not passed here; alpha, input_tensor and pooling are left at their defaults.
base_model = MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights="imagenet",
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5


In [None]:
# Freeze the pre-trained base so that only the layers added on top of it are trained.
base_model.trainable = False

In [None]:
# Add a custom classification head on top of the base model.
# One output unit per class: 'rugby' and 'soccer', matching the class list
# given to the data generators.
NUM_CLASSES = 2

model = tf.keras.models.Sequential()
model.add(base_model)
model.add(GlobalAvgPool2D())  # average each feature map down to a single value
model.add(Dense(1024, activation='relu'))
model.add(Dense(NUM_CLASSES, activation='softmax'))

# Sparse categorical cross-entropy because the generators yield integer class
# labels (class_mode='sparse') rather than one-hot vectors.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
# Print each layer's output shape and the trainable vs. non-trainable parameter counts.
model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 mobilenetv2_1.00_224 (Funct  (None, 7, 7, 1280)       2257984   
 ional)                                                          
                                                                 
 global_average_pooling2d_1   (None, 1280)             0         
 (GlobalAveragePooling2D)                                        
                                                                 
 dense_2 (Dense)             (None, 1024)              1311744   
                                                                 
 dense_3 (Dense)             (None, 10)                10250     
                                                                 
Total params: 3,579,978
Trainable params: 1,321,994
Non-trainable params: 2,257,984
_________________________________________________________________


In [None]:
# Stop training once the validation loss has gone 5 epochs without improving.
elst = callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=5,
)

# Write the model to 'model.hdf5' only when the validation loss reaches a new minimum.
save_ck = callbacks.ModelCheckpoint(
    filepath='model.hdf5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)

In [None]:
# The generators already yield batches of the size they were created with, so
# `batch_size` is not passed to fit() (Keras documents that it should not be
# specified for generator / Sequence input).
# The returned History is kept so the per-epoch metrics can be inspected afterwards.
history = model.fit(
    train_data_gen,
    validation_data=valid_data_gen,
    epochs=10,
    callbacks=[elst, save_ck],
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7d197ef6cd60>