### Initial Setup and Configuration

- Importing necessary libraries and check TensorFlow version.
- Enabling GPU memory growth and set up mixed precision for faster training.

In [1]:
import os
import datetime
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.mixed_precision import set_global_policy

# Check TensorFlow version
print(tf.__version__)

# Enable GPU memory growth
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

# Set up mixed precision for faster training
set_global_policy('mixed_float16')


2.10.1
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA GeForce RTX 4070 Ti, compute capability 8.9


### Preparing the Dataset

- Loading and preparing the training and validation datasets from directory.


In [2]:
dataset_path = './dataset'  # Update this path if necessary
image_size = (224, 224)
batch_size = 32


In [3]:
train_dataset = image_dataset_from_directory(
    dataset_path,
    labels='inferred',
    label_mode='categorical',
    batch_size=batch_size,
    image_size=image_size,
    shuffle=True,
    seed=123,
    validation_split=0.2,
    subset='training',
)

validation_dataset = image_dataset_from_directory(
    dataset_path,
    labels='inferred',
    label_mode='categorical',
    batch_size=batch_size,
    image_size=image_size,
    shuffle=True,
    seed=123,
    validation_split=0.2,
    subset='validation',
)


Found 413993 files belonging to 345 classes.
Using 331195 files for training.
Found 413993 files belonging to 345 classes.
Using 82798 files for validation.


### Building the Model
- Architecture using EfficientNetB0 as the base model and compilation.

In [5]:
def build_model(num_classes):
    inputs = Input(shape=(*image_size, 3))
    base_model = EfficientNetB0(include_top=False, input_tensor=inputs, weights='imagenet')
    base_model.trainable = True  # Fine-tune the base model

    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.2)(x)
    outputs = Dense(num_classes, activation='softmax', dtype=tf.float32)(x)
    
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                loss='categorical_crossentropy',
                metrics=['accuracy'])
    return model

model = build_model(num_classes)
model.summary()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 224, 224, 3)  0           ['input_1[0][0]']                
                                                                                                  
 normalization (Normalization)  (None, 224, 224, 3)  7           ['rescaling[0][0]']              
                                                                                                  
 rescaling_1 (Rescaling)        (None, 224, 224, 3)  0           ['normalization[0][0]']      

### Setting Up Callbacks and Training

- Initializing callbacks including ModelCheckpoint for saving the model during training.


In [6]:
from tensorflow.keras.callbacks import ModelCheckpoint

# Directory where the checkpoints will be saved
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Filename pattern for checkpoints
checkpoint_path = os.path.join(checkpoint_dir, "cp-{epoch:04d}.ckpt")

# Setting up the ModelCheckpoint callback
model_checkpoint = ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=1,
    save_best_only=False  # Save after every epoch
)

callbacks = [model_checkpoint]  # Add other callbacks as needed


- Started the training process and print out the first batch's shape for verification.

In [7]:
epochs = 30  # Adjust as needed

# Start training
history = model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=validation_dataset,
    callbacks=callbacks
)

Epoch 1/30
Epoch 1: saving model to ./checkpoints\cp-0001.ckpt
Epoch 2/30
Epoch 2: saving model to ./checkpoints\cp-0002.ckpt
Epoch 3/30
Epoch 3: saving model to ./checkpoints\cp-0003.ckpt
Epoch 4/30
Epoch 4: saving model to ./checkpoints\cp-0004.ckpt
Epoch 5/30
Epoch 5: saving model to ./checkpoints\cp-0005.ckpt
Epoch 6/30
Epoch 6: saving model to ./checkpoints\cp-0006.ckpt
Epoch 7/30
Epoch 7: saving model to ./checkpoints\cp-0007.ckpt
Epoch 8/30
 1289/10350 [==>...........................] - ETA: 22:18 - loss: 0.4481 - accuracy: 0.8601

InvalidArgumentError: Graph execution error:

2 root error(s) found.
  (0) INVALID_ARGUMENT:  Invalid PNG data, size 23599
	 [[{{node decode_image/DecodeImage}}]]
	 [[IteratorGetNext]]
	 [[categorical_crossentropy/softmax_cross_entropy_with_logits/Shape_2/_6]]
  (1) INVALID_ARGUMENT:  Invalid PNG data, size 23599
	 [[{{node decode_image/DecodeImage}}]]
	 [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_24256]

In [8]:
print("Number of classes:", num_classes)
for images, labels in train_dataset.take(1):
    print("Batch image shape:", images.shape)
    print("Batch label shape:", labels.shape)


Number of classes: 10350
Batch image shape: (32, 224, 224, 3)
Batch label shape: (32, 345)


### Image Integrity Check

- Script to check for corrupt images or unsupported file extensions in the dataset directory.

In [8]:
from PIL import Image
import os

def check_images(s_dir, ext_list):
    bad_images = []
    bad_ext = []
    for folder in os.listdir(s_dir):
        print("Checking", folder)
        folder_path = os.path.join(s_dir, folder)
        for file in os.listdir(folder_path):
            file_path = os.path.join(folder_path, file)
            try:
                img = Image.open(file_path)  # open the image file
                img.verify()  # verify that it is, in fact, an image
            except (IOError, SyntaxError) as e:
                print('Bad file:', file_path)  # print out the names of corrupt files
                bad_images.append(file_path)
            if file.split('.')[-1].lower() not in ext_list:
                print('Bad extension:', file_path)
                bad_ext.append(file_path)
    return bad_images, bad_ext

bad_images, bad_ext = check_images(dataset_path, ['jpg', 'png'])  # Add or remove extensions based on your dataset

# Optionally remove the identified bad images
for img in bad_images:
    os.remove(img)

print(f"Removed {len(bad_images)} images and found {len(bad_ext)} images with unsupported extensions.")


Checking aircraft carrier
Checking airplane
Checking alarm clock
Checking ambulance
Checking angel
Checking animal migration
Checking ant
Checking anvil
Checking apple
Checking arm
Checking asparagus
Checking axe
Checking backpack
Checking banana
Checking bandage
Checking barn
Checking baseball
Checking baseball bat
Checking basket
Checking basketball
Checking bat
Checking bathtub
Checking beach
Checking bear
Checking beard
Checking bed
Checking bee
Checking belt
Checking bench
Checking bicycle
Checking binoculars
Checking bird
Checking birthday cake
Checking blackberry
Checking blueberry
Checking book
Checking boomerang
Checking bottlecap
Checking bowtie
Checking bracelet
Checking brain
Checking bread
Checking bridge
Checking broccoli
Checking broom
Checking bucket
Checking bulldozer
Checking bus
Checking bush
Bad file: ./dataset\bush\5484730390675456.png
Checking butterfly
Checking cactus
Checking cake
Checking calculator
Checking calendar
Checking camel
Checking camera
Checking camo

### Continuing Training from Checkpoints

- Load the latest checkpoint and continue training, adjusting the epoch numbers as needed for subsequent training sessions.

In [7]:
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
    model.load_weights(latest_checkpoint)
    print("Loaded weights from:", latest_checkpoint)
else:
    print("No checkpoint found, starting training from scratch.")


Loaded weights from: ./checkpoints\cp-0035.ckpt


In [9]:
initial_epoch = 8  # This was the last epoch that was completed
total_epochs = 30  # The total number of epochs you want to train

history = model.fit(
    train_dataset,
    epochs=total_epochs,
    validation_data=validation_dataset,
    initial_epoch=initial_epoch,
    callbacks=callbacks
)

Epoch 9/30
Epoch 9: saving model to ./checkpoints\cp-0009.ckpt
Epoch 10/30
Epoch 10: saving model to ./checkpoints\cp-0010.ckpt
Epoch 11/30
Epoch 11: saving model to ./checkpoints\cp-0011.ckpt
Epoch 12/30
Epoch 12: saving model to ./checkpoints\cp-0012.ckpt
Epoch 13/30
Epoch 13: saving model to ./checkpoints\cp-0013.ckpt
Epoch 14/30
Epoch 14: saving model to ./checkpoints\cp-0014.ckpt
Epoch 15/30
Epoch 15: saving model to ./checkpoints\cp-0015.ckpt
Epoch 16/30
Epoch 16: saving model to ./checkpoints\cp-0016.ckpt
Epoch 17/30
Epoch 17: saving model to ./checkpoints\cp-0017.ckpt
Epoch 18/30
   39/10350 [..............................] - ETA: 25:53 - loss: 0.2035 - accuracy: 0.9319

KeyboardInterrupt: 

In [8]:
initial_epoch = 18  # This was the last epoch that was completed
total_epochs = 30  # The total number of epochs you want to train

history = model.fit(
    train_dataset,
    epochs=total_epochs,
    validation_data=validation_dataset,
    initial_epoch=initial_epoch,
    callbacks=callbacks
)

Epoch 19/30
Epoch 19: saving model to ./checkpoints\cp-0019.ckpt
Epoch 20/30
Epoch 20: saving model to ./checkpoints\cp-0020.ckpt
Epoch 21/30
Epoch 21: saving model to ./checkpoints\cp-0021.ckpt
Epoch 22/30
Epoch 22: saving model to ./checkpoints\cp-0022.ckpt
Epoch 23/30
Epoch 23: saving model to ./checkpoints\cp-0023.ckpt
Epoch 24/30
Epoch 24: saving model to ./checkpoints\cp-0024.ckpt
Epoch 25/30
Epoch 25: saving model to ./checkpoints\cp-0025.ckpt
Epoch 26/30
Epoch 26: saving model to ./checkpoints\cp-0026.ckpt
Epoch 27/30
Epoch 27: saving model to ./checkpoints\cp-0027.ckpt
Epoch 28/30
Epoch 28: saving model to ./checkpoints\cp-0028.ckpt
Epoch 29/30
Epoch 29: saving model to ./checkpoints\cp-0029.ckpt
Epoch 30/30
Epoch 30: saving model to ./checkpoints\cp-0030.ckpt


In [10]:
initial_epoch = 30  # This was the last epoch that was completed
total_epochs = 40  # New total number of epochs including the additional epochs

history = model.fit(
    train_dataset,
    epochs=total_epochs,
    validation_data=validation_dataset,
    initial_epoch=initial_epoch,
    callbacks=callbacks
)


Epoch 31/40
Epoch 31: saving model to ./checkpoints\cp-0031.ckpt
Epoch 32/40
Epoch 32: saving model to ./checkpoints\cp-0032.ckpt
Epoch 33/40
Epoch 33: saving model to ./checkpoints\cp-0033.ckpt
Epoch 34/40
   23/10350 [..............................] - ETA: 25:13 - loss: 0.1237 - accuracy: 0.9633

KeyboardInterrupt: 

In [8]:
initial_epoch = 34  # This was the last epoch that was completed
total_epochs = 45  # New total number of epochs including the additional epochs

history = model.fit(
    train_dataset,
    epochs=total_epochs,
    validation_data=validation_dataset,
    initial_epoch=initial_epoch,
    callbacks=callbacks
)


Epoch 35/45
Epoch 35: saving model to ./checkpoints\cp-0035.ckpt
Epoch 36/45

InvalidArgumentError: Graph execution error:

2 root error(s) found.
  (0) INVALID_ARGUMENT:  Invalid PNG data, size 12077
	 [[{{node decode_image/DecodeImage}}]]
	 [[IteratorGetNext]]
	 [[categorical_crossentropy/softmax_cross_entropy_with_logits/Shape_2/_6]]
  (1) INVALID_ARGUMENT:  Invalid PNG data, size 12077
	 [[{{node decode_image/DecodeImage}}]]
	 [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_25337]

### Model Saving and Loading

- Saved the trained model in TensorFlow's SavedModel format and demonstrate loading the saved model.

In [11]:
model_save_path = './models/doodle_recognition'  # Specify your model save path

# Save the model in TensorFlow's SavedModel format
tf.saved_model.save(model, model_save_path)




INFO:tensorflow:Assets written to: ./models/doodle_recognition\assets


INFO:tensorflow:Assets written to: ./models/doodle_recognition\assets
