In [28]:
# Import necessary libraries
import os
import json
import imghdr
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint


In [29]:
# Manually set the path to the dataset directory
images_dir = '/Users/manish/Downloads/PokemonData'  # Replace this with the actual path to your dataset


In [30]:
# Verify if the dataset exists
if not os.path.exists(images_dir):
    raise FileNotFoundError(f"Dataset directory not found at: {images_dir}")
print(f"Dataset directory confirmed: {images_dir}")

Dataset directory confirmed: /Users/manish/Downloads/PokemonData


In [31]:
# Step 1: Verify and clean the dataset
def check_images(directory):
    """
    Verifies the integrity of images in the directory and removes corrupted ones.
    """
    for root, dirs, files in os.walk(directory):
        for file in files:
            img_path = os.path.join(root, file)
            try:
                with Image.open(img_path) as img:
                    img.verify()  # Verifies the image integrity
            except (IOError, SyntaxError) as e:
                print(f"Corrupted image found: {img_path}")
                os.remove(img_path)  # Optionally delete the corrupted image

def convert_images_to_rgb(directory):
    """
    Converts all images to RGB format (if not already in RGB format).
    """
    for root, dirs, files in os.walk(directory):
        for file in files:
            img_path = os.path.join(root, file)
            try:
                with Image.open(img_path) as img:
                    if img.mode != 'RGB':  # Convert non-RGB images to RGB
                        img = img.convert('RGB')
                    img.save(img_path)  # Overwrite the image in the same directory
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")

In [32]:
def preprocess_images(directory):
    """
    Ensures all images in the directory are in RGB format.
    Converts palette-based images with transparency to RGBA, and saves them as PNG.
    JPEG images are saved without any transparency.
    """
    for root, _, files in os.walk(directory):
        for file in files:
            img_path = os.path.join(root, file)

            # Check if the file is a valid image
            if imghdr.what(img_path) is None:
                print(f"Skipping invalid file: {img_path}")
                continue  # Skip invalid files

            try:
                with Image.open(img_path) as img:
                    # If the image is PNG and has transparency, save it as PNG
                    if img.mode in ("P", "L", "LA"):
                        img = img.convert("RGBA")
                        if img.format == "PNG":  # Save as PNG to preserve transparency
                            img.save(img_path)
                        else:  # If it's not a PNG, save it in a format that doesn't support transparency (e.g., JPEG)
                            img = img.convert("RGB")  # Remove alpha channel and save as RGB
                            img.save(img_path, format="JPEG")
                    else:
                        # For images that don't have transparency (e.g., JPG), save as RGB
                        if img.mode != "RGB":
                            img = img.convert("RGB")
                        img.save(img_path)
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")

# Apply preprocessing
print("Preprocessing images...")
preprocess_images(images_dir)
print("Preprocessing complete.")


Preprocessing images...
Skipping invalid file: /Users/manish/Downloads/PokemonData/Venusaur/e20dd0c9dbae4a299b32be5f486e4143.jpg
Skipping invalid file: /Users/manish/Downloads/PokemonData/Onix/Onix58.jpg
Skipping invalid file: /Users/manish/Downloads/PokemonData/Onix/a445ddff0c6640a381e2da954f117e88.jpg
Skipping invalid file: /Users/manish/Downloads/PokemonData/Onix/Onix46.jpg
Skipping invalid file: /Users/manish/Downloads/PokemonData/Seadra/a9e0e66523a5477fb8cb58b57d89af24.jpg
Preprocessing complete.


In [33]:
# Initialize ImageDataGenerator
datagen = ImageDataGenerator(
    rescale=1./255,  # Normalize pixel values
    validation_split=0.2  # Split for training and validation
)

# Prepare training and validation data generators
train_gen = datagen.flow_from_directory(
    images_dir,  # Base directory
    target_size=(224, 224),  # Resize images for the model
    batch_size=32,
    class_mode='categorical',
    subset='training'
)

val_gen = datagen.flow_from_directory(
    images_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='validation'
)

# Log dataset stats
print(f"Training data: {train_gen.samples} images across {len(train_gen.class_indices)} classes.")
print(f"Validation data: {val_gen.samples} images across {len(val_gen.class_indices)} classes.")


Found 16137 images belonging to 151 classes.
Found 3962 images belonging to 151 classes.
Training data: 16137 images across 151 classes.
Validation data: 3962 images across 151 classes.


In [34]:
# Define CNN model
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(len(train_gen.class_indices), activation='softmax')
])

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [35]:
# Define callbacks with path in the writable directory
model_filename = "best_model_pokemon_custom_cnn.keras"  # Save the model to the current directory or specify another path
callbacks = [
    EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True),
    ModelCheckpoint(model_filename, save_best_only=True, verbose=1)
]

# Train the model
print("Starting model training...")
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=10,
    verbose=1,
    callbacks=callbacks
)

# Save training results
results = {'model_history': history.history}
results_path = "model_results.json"  # Save results to the current directory or specify another path
with open(results_path, "w") as f:
    json.dump(results, f)

# Save model in .h5 format in the current directory or specify another path
model_h5_path = "pokemon_model.h5"
model.save(model_h5_path)

# Print the paths where results are saved
print(f"Training complete. Results saved to '{results_path}'. Model saved to '{model_filename}' and '{model_h5_path}'.")


Starting model training...
Epoch 1/10


  self._warn_if_super_not_called()


[1m144/505[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m3:31[0m 586ms/step - accuracy: 0.0082 - loss: 5.0799

2024-11-18 02:16:03.002426: W tensorflow/core/framework/op_kernel.cc:1828] UNKNOWN: UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x15d3bf970>
Traceback (most recent call last):

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 270, in __call__
    ret = func(*args)
          ^^^^^^^^^^^

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/tensorflow/python/data/ops/from_generator_op.py", line 198, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py", line 260, in _get_iterator
    fo

[1m145/505[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m3:30[0m 586ms/step - accuracy: 0.0082 - loss: 5.0795

UnknownError: Graph execution error:

Detected at node PyFunc defined at (most recent call last):
<stack traces unavailable>
UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x15d3bf970>
Traceback (most recent call last):

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 270, in __call__
    ret = func(*args)
          ^^^^^^^^^^^

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/tensorflow/python/data/ops/from_generator_op.py", line 198, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py", line 260, in _get_iterator
    for i, batch in enumerate(gen_fn()):

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py", line 253, in generator_fn
    yield self.py_dataset[i]
          ~~~~~~~~~~~~~~~^^^

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/keras/src/legacy/preprocessing/image.py", line 68, in __getitem__
    return self._get_batches_of_transformed_samples(index_array)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/keras/src/legacy/preprocessing/image.py", line 313, in _get_batches_of_transformed_samples
    img = image_utils.load_img(
          ^^^^^^^^^^^^^^^^^^^^^

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/keras/src/utils/image_utils.py", line 236, in load_img
    img = pil_image.open(io.BytesIO(f.read()))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/Users/manish/tensorflow_env/lib/python3.12/site-packages/PIL/Image.py", line 3498, in open
    raise UnidentifiedImageError(msg)

PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x15d3bf970>


	 [[{{node PyFunc}}]]
	 [[IteratorGetNext]] [Op:__inference_one_step_on_iterator_7377]

In [21]:
# Displaying training and validation accuracy
plt.figure(figsize=(12, 6))

# Plot training & validation accuracy values
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(['Train', 'Validation'], loc='upper left')

# Plot training & validation loss values
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(['Train', 'Validation'], loc='upper left')

plt.tight_layout()
plt.show()


NameError: name 'plt' is not defined