In [2]:
# Standard library
import os
import json
import imghdr  # NOTE: deprecated since Python 3.11 and removed in 3.13 (PEP 594)

# Third-party: plotting and image handling
import matplotlib.pyplot as plt
from PIL import Image, UnidentifiedImageError

# Third-party: TensorFlow / Keras
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# The module is `callbacks`; the previous spelling `callbackas` cannot be imported.
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

  import imghdr


In [4]:
# Path to the dataset directory. The later generator cell reads it with
# flow_from_directory, i.e. it is expected to contain one sub-folder per class.
# Set the POKEMON_DATA_DIR environment variable to run the notebook on another
# machine without editing this cell; the original location stays the default.
images_dir = os.environ.get('POKEMON_DATA_DIR', '/Users/manish/Downloads/PokemonData')

# Fail fast if the dataset is missing. isdir() is used instead of exists() so
# that a regular file sitting at this path is rejected as well.
if not os.path.isdir(images_dir):
    raise FileNotFoundError(f"Dataset directory not found at: {images_dir}")
print(f"Dataset directory confirmed: {images_dir}")


Dataset directory confirmed: /Users/manish/Downloads/PokemonData


In [5]:
# Verify and clean the dataset
def check_images(directory, delete=True, validator=None):
    """Walk ``directory`` and report every file that fails image validation.

    Args:
        directory: Root folder that is scanned recursively.
        delete: When True (the default, which matches the previous behaviour)
            every invalid file is removed from disk. Pass False for a dry run
            that only reports what would be removed.
        validator: Callable that takes a file path and returns True for a
            usable image. Defaults to ``is_valid_image``.

    Returns:
        A list with the paths that failed validation, in discovery order.
    """
    if validator is None:
        # Resolved at call time so the helper may be defined after this function.
        validator = is_valid_image

    invalid_paths = []
    for root, _dirs, files in os.walk(directory):
        for file in files:
            img_path = os.path.join(root, file)
            if validator(img_path):
                continue

            print(f"Corrupted or invalid image found: {img_path}")
            invalid_paths.append(img_path)
            if not delete:
                continue

            try:
                os.remove(img_path)
            except OSError as exc:
                # A single undeletable file must not abort the whole scan.
                print(f"Could not delete {img_path}: {exc}")
    return invalid_paths

def is_valid_image(img_path):
    """Return True if Pillow can open ``img_path`` and ``verify()`` passes, else False."""
    try:
        with Image.open(img_path) as candidate:
            # verify() raises when Pillow detects a problem with the file
            candidate.verify()
    except (IOError, SyntaxError, UnidentifiedImageError):
        return False
    else:
        return True

def preprocess_images(directory):
    """Normalise the colour mode of every valid image under ``directory`` in place.

    Files that ``imghdr`` does not recognise, or that fail ``is_valid_image``,
    are reported and skipped. For the remaining files:

    * mode ``P``, ``L`` or ``LA``: converted to RGBA; a PNG source is written
      back as PNG, any other source is flattened to RGB and written as JPEG;
    * any other non-RGB mode: converted to RGB and written back in the
      format the file was read as;
    * mode ``RGB``: left untouched, so repeated runs do not re-encode
      (and, for JPEG, progressively degrade) files that need no change.

    Args:
        directory: Root folder that is walked recursively.

    Returns:
        None. Problems with individual files are printed and do not stop the
        walk.
    """
    for root, _, files in os.walk(directory):
        for file in files:
            img_path = os.path.join(root, file)

            # Check if the file is a valid image
            if imghdr.what(img_path) is None or not is_valid_image(img_path):
                print(f"Skipping invalid or corrupted file: {img_path}")
                continue  # Skip invalid files

            try:
                with Image.open(img_path) as img:
                    # Remember the container format now: convert() returns a
                    # new image whose .format is None, so checking it after
                    # the conversion could never match "PNG".
                    source_format = img.format

                    if img.mode in ("P", "L", "LA"):
                        rgba = img.convert("RGBA")
                        if source_format == "PNG":
                            # PNG can store the alpha channel
                            rgba.save(img_path, format="PNG")
                        else:
                            # JPEG has no alpha channel, so drop it first
                            rgba.convert("RGB").save(img_path, format="JPEG")
                    elif img.mode != "RGB":
                        # Keep the original container format instead of
                        # guessing it from the file extension.
                        img.convert("RGB").save(img_path, format=source_format)
                    # Already RGB: nothing to rewrite.
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")

# Run the two clean-up stages in order, announcing each one before it starts
for banner, stage in (
    ("Checking and cleaning the dataset...", check_images),
    ("Preprocessing images...", preprocess_images),
):
    print(banner)
    stage(images_dir)
print("Preprocessing complete.")

Checking and cleaning the dataset...
Corrupted or invalid image found: /Users/manish/Downloads/PokemonData/.DS_Store
Preprocessing images...
Preprocessing complete.


In [6]:
# Input-pipeline settings, kept in one place so the training and validation
# generators cannot drift apart.
IMG_SIZE = (224, 224)     # target_size every image is resized to
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2    # fraction of the images reserved for validation
SEED = 42                 # makes the training shuffle order repeatable

# Initialize ImageDataGenerator
datagen = ImageDataGenerator(
    rescale=1./255,  # Normalize pixel values
    validation_split=VALIDATION_SPLIT  # Split for training and validation
)

# Prepare training and validation data generators
train_gen = datagen.flow_from_directory(
    images_dir,  # Base directory
    target_size=IMG_SIZE,  # Resize images for the model
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training',
    seed=SEED
)

val_gen = datagen.flow_from_directory(
    images_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    # Fixed order: validation metrics do not depend on it, and predictions
    # made from this generator line up with val_gen.classes.
    shuffle=False
)

# Log dataset stats
print(f"Training data: {train_gen.samples} images across {len(train_gen.class_indices)} classes.")
print(f"Validation data: {val_gen.samples} images across {len(val_gen.class_indices)} classes.")

Found 16133 images belonging to 151 classes.
Found 3961 images belonging to 151 classes.
Training data: 16133 images across 151 classes.
Validation data: 3961 images across 151 classes.


In [None]:
# Define CNN model: three Conv2D + MaxPooling2D stages followed by a dense
# classifier head with one softmax output per class folder.
model = Sequential([
    # Declare the input explicitly instead of passing input_shape to the first
    # layer, which Keras warns about for Sequential models. 224x224 matches the
    # generators' target_size; 3 is the channel count.
    tf.keras.Input(shape=(224, 224, 3)),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(len(train_gen.class_indices), activation='softmax')
])

# Compile the model. categorical_crossentropy pairs with the one-hot labels
# produced by class_mode='categorical' in the generators.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Callbacks: stop once val_loss has not improved for 3 epochs (restoring the
# best weights), and write a checkpoint to model_filename whenever the
# monitored value improves.
model_filename = "best_model_pokemon_custom_cnn.keras"
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
checkpoint = ModelCheckpoint(model_filename, save_best_only=True, verbose=1)
callbacks = [early_stopping, checkpoint]

# Fit on the training generator, validating on val_gen after every epoch
print("Starting model training...")
fit_options = dict(
    validation_data=val_gen,
    epochs=10,
    verbose=1,
    callbacks=callbacks,
)
history = model.fit(train_gen, **fit_options)

# Save training results. Every logged value is cast to a built-in float so
# json.dump cannot fail on NumPy scalar types in the history.
results = {
    'model_history': {
        metric: [float(value) for value in values]
        for metric, values in history.history.items()
    }
}
results_path = "model_results.json"
with open(results_path, "w") as f:
    json.dump(results, f)

# Save model in .h5 format in the current directory
model_h5_path = "pokemon_model.h5"
model.save(model_h5_path)

# Print the paths where results are saved
print(f"Training complete. Results saved to '{results_path}'. Model saved to '{model_filename}' and '{model_h5_path}'.")

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Starting model training...
Epoch 1/10


  self._warn_if_super_not_called()


505/505 ━━━━━━━━━━━━━━━━━━━━ 0s 610ms/step - accuracy: 0.0174 - loss: 4.9432
Epoch 1: val_loss improved from inf to 3.81833, saving model to best_model_pokemon_custom_cnn.keras
505/505 ━━━━━━━━━━━━━━━━━━━━ 332s 655ms/step - accuracy: 0.0175 - loss: 4.9427 - val_accuracy: 0.1444 - val_loss: 3.8183
Epoch 2/10
505/505 ━━━━━━━━━━━━━━━━━━━━ 0s 664ms/step - accuracy: 0.0984 - loss: 3.9808
Epoch 2: val_loss improved from 3.81833 to 2.75882, saving model to best_model_pokemon_custom_cnn.keras
505/505 ━━━━━━━━━━━━━━━━━━━━ 365s 723ms/step - accuracy: 0.0984 - loss: 3.9804 - val_accuracy: 0.3552 - val_loss: 2.7588
Epoch 3/10
 81/505 ━━━━━━━━━━━━━━━━━━━━ 6:55 979ms/step - accuracy: 0.1891 - loss: 3.2648

In [None]:
# Training curves side by side: accuracy on the left, loss on the right
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# (training key, validation key, panel title, y-axis label) for each panel
panels = (
    ('accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Model Loss', 'Loss'),
)

for ax, (train_key, val_key, title, y_label) in zip(axes, panels):
    ax.plot(history.history[train_key])
    ax.plot(history.history[val_key])
    ax.set_title(title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(y_label)
    ax.legend(['Train', 'Validation'], loc='upper left')

fig.tight_layout()
plt.show()
