## Image Classification with Transfer Learning
### This notebook trains a CNN-based classifier to distinguish between cats and dogs.

**Dataset:**  
The dataset has been imported from Kaggle, the cats and dogs dataset

**Approach:**  
We use an object-oriented approach with TensorFlow/Keras and transfer learning 
(using EfficientNetB0) to achieve at least 90% accuracy.

**Deliverables:**
- Trained model saved as `../models/cnn_model.h5`
- This notebook (`image_classification.ipynb`)

In [32]:
# Import the necessary libraries
import numpy as np
import os
from pathlib import Path

import cv2
import tensorflow as tf
from PIL import Image
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Rescaling


def check_versions():
    """Print version information for key libraries."""
    print(f"TensorFlow version: {tf.__version__}")
    print(f"OpenCV version: {cv2.__version__}")
    print(f"Pillow (PIL) version: {Image.__version__}")


if __name__ == "__main__":
    check_versions()

TensorFlow version: 2.19.0
OpenCV version: 4.11.0
Pillow (PIL) version: 11.1.0


In [33]:
# Define the ImageClassifier class
class RobustImageClassifier:
    def __init__(self, data_dir, img_size=(224, 224), batch_size=16, val_split=0.2, seed=42):
        self.data_dir = data_dir
        self.img_size = img_size
        self.batch_size = batch_size  # Reduced default batch size
        self.val_split = val_split
        self.seed = seed
        self.train_ds = None
        self.val_ds = None
        self.model = None
        self.corrupted_files = []
        self._verified_images = set()

    def _validate_image_file(self, file_path):
        """Thoroughly validate an image file using both PIL and TensorFlow."""
        if file_path in self._verified_images:
            return True
            
        try:
            # First check with PIL
            with Image.open(file_path) as img:
                img.verify()  # Verify file integrity
                if len(img.getbands()) not in [1, 3]:
                    print(f"Invalid channels in {file_path}")
                    return False
                
            # Then check with TensorFlow
            img_data = tf.io.read_file(file_path)
            img = tf.image.decode_image(img_data, channels=3, expand_animations=False)
            if img.shape.rank != 3 or img.shape[-1] not in [1, 3]:
                print(f"Invalid tensor shape in {file_path}")
                return False
                
            self._verified_images.add(file_path)
            return True
            
        except Exception as e:
            print(f"Corrupted image detected: {file_path} - {str(e)}")
            self.corrupted_files.append(file_path)
            return False

    def clean_dataset(self):
        """Remove corrupted files and non-image files from dataset."""
        print("\n=== Cleaning Dataset ===")
        removed_count = 0
        valid_extensions = ('.jpg', '.jpeg', '.png')
        
        for root, _, files in os.walk(self.data_dir):
            for file in files:
                file_path = os.path.join(root, file)
                
                # Remove non-image files
                if not file.lower().endswith(valid_extensions):
                    try:
                        os.remove(file_path)
                        removed_count += 1
                        print(f"Removed non-image file: {file_path}")
                        continue
                    except Exception as e:
                        print(f"Error removing {file_path}: {str(e)}")
                        continue
                
                # Remove corrupted images
                if not self._validate_image_file(file_path):
                    try:
                        os.remove(file_path)
                        removed_count += 1
                        print(f"Removed corrupted image: {file_path}")
                    except Exception as e:
                        print(f"Error removing corrupted file {file_path}: {str(e)}")
        
        print(f"\nCleaning complete. Removed {removed_count} files.")
        print(f"Found {len(self.corrupted_files)} corrupted images.")
        return removed_count

    def _load_and_validate_image(self, file_path, label):
        """Robust image loading with comprehensive validation."""
        try:
            # Read and decode with explicit error handling
            img_data = tf.io.read_file(file_path)
            img = tf.image.decode_image(img_data, channels=3, expand_animations=False)
            
            # Validate tensor properties
            if img.shape.rank != 3 or img.shape[-1] not in [1, 3]:
                raise ValueError(f"Invalid tensor shape: {img.shape}")
                
            # Convert and resize
            img = tf.image.convert_image_dtype(img, tf.float32)
            img = tf.image.resize(img, self.img_size)
            
            # Ensure 3 channels
            if img.shape[-1] == 1:  # Grayscale
                img = tf.image.grayscale_to_rgb(img)
            elif img.shape[-1] == 4:  # RGBA
                img = img[..., :3]
                
            return img, label
            
        except Exception as e:
            print(f"Skipping corrupted image {file_path}: {str(e)}")
            # Return zero tensor that will be filtered out
            return tf.zeros((*self.img_size, 3)), label

    def _filter_valid_images(self, img, label):
        """Filter out invalid images (all zeros)."""
        return tf.reduce_sum(tf.abs(img)) > 0.0

    def load_data(self):
        """Load dataset with robust validation and error handling."""
        print("\n=== Loading Dataset ===")
        
        # First build list of validated image paths
        image_paths = []
        labels = []
        class_names = sorted(os.listdir(self.data_dir))
        label_to_index = {name: i for i, name in enumerate(class_names)}
        
        for class_name in class_names:
            class_dir = os.path.join(self.data_dir, class_name)
            for file in os.listdir(class_dir):
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    file_path = os.path.join(class_dir, file)
                    if self._validate_image_file(file_path):
                        image_paths.append(file_path)
                        labels.append(label_to_index[class_name])
        
        # Create dataset from validated paths
        path_ds = tf.data.Dataset.from_tensor_slices((image_paths, labels))
        
        # Apply loading and filtering
        image_ds = path_ds.map(
            lambda x, y: self._load_and_validate_image(x, y),
            num_parallel_calls=tf.data.AUTOTUNE
        ).filter(self._filter_valid_images)
        
        # Shuffle and split
        dataset_size = len(image_paths)
        train_size = int((1 - self.val_split) * dataset_size)
        
        self.train_ds = (image_ds.take(train_size)
                        .shuffle(1024, seed=self.seed)
                        .batch(self.batch_size)
                        .prefetch(tf.data.AUTOTUNE))
        
        self.val_ds = (image_ds.skip(train_size)
                      .batch(self.batch_size)
                      .prefetch(tf.data.AUTOTUNE))
        
        print(f"Successfully loaded {dataset_size} valid images "
              f"({train_size} training, {dataset_size - train_size} validation)")
        return self.train_ds, self.val_ds

    def build_model(self, fine_tune_at=100):
        """
        Build transfer learning model with EfficientNetB0 base.
        Includes input validation and automatic channel handling.
        """
        # Input validation
        if not isinstance(self.img_size, tuple) or len(self.img_size) != 2:
            raise ValueError("img_size must be a tuple of (height, width)")
        
        print("\n=== Building Model ===")
        try:
            base_model = EfficientNetB0(
                weights="imagenet",
                include_top=False,
                input_shape=(*self.img_size, 3)
            )
            base_model.trainable = False

            inputs = tf.keras.Input(shape=(*self.img_size, 3))
            x = layers.Rescaling(1./255)(inputs)
            x = base_model(x, training=False)
            x = layers.GlobalAveragePooling2D()(x)
            x = layers.Dropout(0.2)(x)
            outputs = layers.Dense(1, activation="sigmoid")(x)
            
            self.model = tf.keras.Model(inputs, outputs)
            
            self.model.compile(
                optimizer=optimizers.Adam(),
                loss="binary_crossentropy",
                metrics=["accuracy"]
            )
            
            print("Model built successfully!")
            return self.model
            
        except Exception as e:
            print(f"Error building model: {str(e)}")
            raise

    def fine_tune_model(self, fine_tune_at=100):
        """Unfreeze layers for fine-tuning with validation."""
        if not hasattr(self, 'model') or self.model is None:
            raise ValueError("Model must be built before fine-tuning")
            
        print("\n=== Fine-Tuning Model ===")
        try:
            base_model = self.model.layers[2]
            base_model.trainable = True
            for layer in base_model.layers[:fine_tune_at]:
                layer.trainable = False
                
            self.model.compile(
                optimizer=optimizers.Adam(1e-5),
                loss="binary_crossentropy",
                metrics=["accuracy"]
            )
            print("Model ready for fine-tuning!")
            return self.model
            
        except Exception as e:
            print(f"Error configuring fine-tuning: {str(e)}")
            raise

    def train_model(self, epochs=10):
        """Train model with comprehensive error handling and callbacks."""
        if self.train_ds is None or self.val_ds is None:
            raise ValueError("Data must be loaded before training")
            
        print("\n=== Training Model ===")
        try:
            os.makedirs("models", exist_ok=True)
            callbacks = [
                EarlyStopping(monitor="val_accuracy", patience=3, restore_best_weights=True),
                ModelCheckpoint(
                    "models/best_model.h5",
                    monitor="val_accuracy",
                    save_best_only=True,
                    save_weights_only=False
                )
            ]
            
            history = self.model.fit(
                self.train_ds,
                validation_data=self.val_ds,
                epochs=epochs,
                callbacks=callbacks,
                verbose=1
            )
            
            print("\nTraining completed successfully!")
            return history
            
        except Exception as e:
            print(f"\nTraining failed: {str(e)}")
            print("Possible solutions:")
            print("- Reduce batch size if memory error")
            print("- Check image dimensions and channels")
            print("- Verify dataset contains valid images")
            raise

    def evaluate_model(self):
        """Evaluate model with proper validation checks."""
        if self.val_ds is None:
            raise ValueError("Validation data not loaded")
            
        print("\n=== Evaluating Model ===")
        try:
            loss, accuracy = self.model.evaluate(self.val_ds, verbose=1)
            print(f"Validation Loss: {loss:.4f}")
            print(f"Validation Accuracy: {accuracy:.4f}")
            return loss, accuracy
            
        except Exception as e:
            print(f"Evaluation failed: {str(e)}")
            raise

    def save_model(self, filepath="models/cnn_model.h5"):
        """Save model with path validation."""
        try:
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            self.model.save(filepath)
            print(f"Model successfully saved to {filepath}")
        except Exception as e:
            print(f"Error saving model: {str(e)}")
            raise

    def get_corrupted_files(self):
        """Return list of detected corrupted files."""
        return self.corrupted_files

In [34]:
# Define the workflow for robust training
def main():
    """Robust training workflow with comprehensive error handling."""
    config = {
        "data_dir": "../data/pet_images",
        "img_size": (224, 224),
        "batch_size": 16,  # Reduced for stability
        "val_split": 0.2,
        "initial_epochs": 5,  # Start with fewer epochs
        "target_accuracy": 0.90,
        "model_save_path": "../models/robust_cnn_model.h5"
    }

    print("\n=== Starting Robust Training Pipeline ===")
    
    try:
        # Initialize with enhanced classifier
        classifier = RobustImageClassifier(
            data_dir=config["data_dir"],
            img_size=config["img_size"],
            batch_size=config["batch_size"],
            val_split=config["val_split"]
        )
        
        # Clean dataset thoroughly
        print("\n=== Cleaning Dataset ===")
        cleaned = classifier.clean_dataset()
        if cleaned > 0:
            print(f"Removed {cleaned} problematic files. Please verify your dataset.")
        
        # Load data with validation
        print("\n=== Loading Data ===")
        train_ds, val_ds = classifier.load_data()
        
        # Quick verification
        print("\nSample batch verification:")
        for images, labels in train_ds.take(1):
            print(f"Images shape: {images.shape}, dtype: {images.dtype}")
            print(f"Labels shape: {labels.shape}, unique values: {np.unique(labels.numpy())}")
        
        # Build and train model
        print("\n=== Building Model ===")
        classifier.build_model()
        
        print("\n=== Training Model ===")
        history = classifier.train_model(epochs=config["initial_epochs"])
        
        # Evaluation
        print("\n=== Evaluation ===")
        loss, accuracy = classifier.evaluate_model()
        
        # Save model
        print("\n=== Saving Model ===")
        classifier.save_model(config["model_save_path"])
        
        print("\n=== Pipeline Completed Successfully ===")
        return history, accuracy
        
    except Exception as e:
        print("\n=== Pipeline Failed ===")
        print(f"Error: {str(e)}")
        print("\nRecommended Actions:")
        print("1. Check the problematic_files list in the classifier")
        print("2. Manually verify the images mentioned in the error logs")
        print("3. Consider reducing batch size further if memory issues persist")
        print("4. Try with a smaller subset of data to isolate the issue")
        raise

if __name__ == "__main__":
    try:
        history, accuracy = main()
        print(f"\nFinal validation accuracy: {accuracy:.2%}")
    except Exception as e:
        print("\nTraining failed. Please address the issues and try again.")
        print(f"Last error: {str(e)}")


=== Starting Robust Training Pipeline ===

=== Cleaning Dataset ===

=== Cleaning Dataset ===
Corrupted image detected: ../data/pet_images\Cat\4351.jpg - {{function_node __wrapped__DecodeImage_device_/job:localhost/replica:0/task:0/device:CPU:0}} Input size should match (header_size + row_size * abs_height) but they differ by 2 [Op:DecodeImage] name: 
Removed corrupted image: ../data/pet_images\Cat\4351.jpg
Corrupted image detected: ../data/pet_images\Dog\11233.jpg - {{function_node __wrapped__DecodeImage_device_/job:localhost/replica:0/task:0/device:CPU:0}} Number of channels inherent in the image must be 1, 3 or 4, was 2 [Op:DecodeImage] name: 
Removed corrupted image: ../data/pet_images\Dog\11233.jpg
Corrupted image detected: ../data/pet_images\Dog\11912.jpg - {{function_node __wrapped__DecodeImage_device_/job:localhost/replica:0/task:0/device:CPU:0}} Number of channels inherent in the image must be 1, 3 or 4, was 2 [Op:DecodeImage] name: 
Removed corrupted image: ../data/pet_image



[1m1248/1248[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m680s[0m 534ms/step - accuracy: 0.9828 - loss: 0.0593 - val_accuracy: 1.0000 - val_loss: 0.0016
Epoch 2/5
[1m1248/1248[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m623s[0m 499ms/step - accuracy: 0.9383 - loss: 0.2376 - val_accuracy: 1.0000 - val_loss: 0.0027
Epoch 3/5
[1m1248/1248[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m653s[0m 523ms/step - accuracy: 0.9386 - loss: 0.2268 - val_accuracy: 1.0000 - val_loss: 0.0020
Epoch 4/5
[1m1248/1248[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m614s[0m 491ms/step - accuracy: 0.9371 - loss: 0.2414 - val_accuracy: 1.0000 - val_loss: 0.0025

Training completed successfully!

=== Evaluation ===

=== Evaluating Model ===
[1m312/312[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m141s[0m 383ms/step - accuracy: 1.0000 - loss: 0.0016




Validation Loss: 0.0016
Validation Accuracy: 1.0000

=== Saving Model ===
Model successfully saved to ../models/robust_cnn_model.h5

=== Pipeline Completed Successfully ===

Final validation accuracy: 100.00%
